http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/LossFunctions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/LossFunctions.java b/core/src/main/java/hivemall/optimizer/LossFunctions.java new file mode 100644 index 0000000..d11be9b --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/LossFunctions.java @@ -0,0 +1,467 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.optimizer; + +import hivemall.utils.math.MathUtils; + +/** + * @link https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions + */ +public final class LossFunctions { + + public enum LossType { + SquaredLoss, LogLoss, HingeLoss, SquaredHingeLoss, QuantileLoss, EpsilonInsensitiveLoss + } + + public static LossFunction getLossFunction(String type) { + if ("SquaredLoss".equalsIgnoreCase(type)) { + return new SquaredLoss(); + } else if ("LogLoss".equalsIgnoreCase(type)) { + return new LogLoss(); + } else if ("HingeLoss".equalsIgnoreCase(type)) { + return new HingeLoss(); + } else if ("SquaredHingeLoss".equalsIgnoreCase(type)) { + return new SquaredHingeLoss(); + } else if ("QuantileLoss".equalsIgnoreCase(type)) { + return new QuantileLoss(); + } else if ("EpsilonInsensitiveLoss".equalsIgnoreCase(type)) { + return new EpsilonInsensitiveLoss(); + } + throw new IllegalArgumentException("Unsupported type: " + type); + } + + public static LossFunction getLossFunction(LossType type) { + switch (type) { + case SquaredLoss: + return new SquaredLoss(); + case LogLoss: + return new LogLoss(); + case HingeLoss: + return new HingeLoss(); + case SquaredHingeLoss: + return new SquaredHingeLoss(); + case QuantileLoss: + return new QuantileLoss(); + case EpsilonInsensitiveLoss: + return new EpsilonInsensitiveLoss(); + default: + throw new IllegalArgumentException("Unsupported type: " + type); + } + } + + public interface LossFunction { + + /** + * Evaluate the loss function. + * + * @param p The prediction, p = w^T x + * @param y The true value (aka target) + * @return The loss evaluated at `p` and `y`. + */ + public float loss(float p, float y); + + public double loss(double p, double y); + + /** + * Evaluate the derivative of the loss function with respect to the prediction `p`. + * + * @param p The prediction, p = w^T x + * @param y The true value (aka target) + * @return The derivative of the loss function w.r.t. `p`. 
+ */ + public float dloss(float p, float y); + + public boolean forBinaryClassification(); + + public boolean forRegression(); + + } + + public static abstract class BinaryLoss implements LossFunction { + + protected static void checkTarget(float y) { + if (!(y == 1.f || y == -1.f)) { + throw new IllegalArgumentException("target must be [+1,-1]: " + y); + } + } + + protected static void checkTarget(double y) { + if (!(y == 1.d || y == -1.d)) { + throw new IllegalArgumentException("target must be [+1,-1]: " + y); + } + } + + @Override + public boolean forBinaryClassification() { + return true; + } + + @Override + public boolean forRegression() { + return false; + } + } + + public static abstract class RegressionLoss implements LossFunction { + + @Override + public boolean forBinaryClassification() { + return false; + } + + @Override + public boolean forRegression() { + return true; + } + + } + + /** + * Squared loss for regression problems. + * + * If you're trying to minimize the mean error, use squared-loss. + */ + public static final class SquaredLoss extends RegressionLoss { + + @Override + public float loss(float p, float y) { + final float z = p - y; + return z * z * 0.5f; + } + + @Override + public double loss(double p, double y) { + final double z = p - y; + return z * z * 0.5d; + } + + @Override + public float dloss(float p, float y) { + return p - y; // 2 (p - y) / 2 + } + } + + /** + * Logistic regression loss for binary classification with y in {-1, 1}. + */ + public static final class LogLoss extends BinaryLoss { + + /** + * <code>logloss(p,y) = log(1+exp(-p*y))</code> + */ + @Override + public float loss(float p, float y) { + checkTarget(y); + + final float z = y * p; + if (z > 18.f) { + return (float) Math.exp(-z); + } + if (z < -18.f) { + return -z; + } + return (float) Math.log(1.d + Math.exp(-z)); + } + + @Override + public double loss(double p, double y) { + checkTarget(y); + + final double z = y * p; + if (z > 18.d) { + return Math.exp(-z); + } + if (z < -18.d) { + return -z; + } + return Math.log(1.d + Math.exp(-z)); + } + + @Override + public float dloss(float p, float y) { + checkTarget(y); + + float z = y * p; + if (z > 18.f) { + return (float) Math.exp(-z) * -y; + } + if (z < -18.f) { + return -y; + } + return -y / ((float) Math.exp(z) + 1.f); + } + } + + /** + * Hinge loss for binary classification tasks with y in {-1,1}. + */ + public static final class HingeLoss extends BinaryLoss { + + private float threshold; + + public HingeLoss() { + this(1.f); + } + + /** + * @param threshold Margin threshold. When threshold=1.0, one gets the loss used by SVM. + * When threshold=0.0, one gets the loss used by the Perceptron. + */ + public HingeLoss(float threshold) { + this.threshold = threshold; + } + + public void setThreshold(float threshold) { + this.threshold = threshold; + } + + @Override + public float loss(float p, float y) { + float loss = hingeLoss(p, y, threshold); + return (loss > 0.f) ? loss : 0.f; + } + + @Override + public double loss(double p, double y) { + double loss = hingeLoss(p, y, threshold); + return (loss > 0.d) ? loss : 0.d; + } + + @Override + public float dloss(float p, float y) { + float loss = hingeLoss(p, y, threshold); + return (loss > 0.f) ? -y : 0.f; + } + } + + /** + * Squared Hinge loss for binary classification tasks with y in {-1,1}. 
+ */ + public static final class SquaredHingeLoss extends BinaryLoss { + + @Override + public float loss(float p, float y) { + return squaredHingeLoss(p, y); + } + + @Override + public double loss(double p, double y) { + return squaredHingeLoss(p, y); + } + + @Override + public float dloss(float p, float y) { + checkTarget(y); + + float d = 1 - (y * p); + return (d > 0.f) ? -2.f * d * y : 0.f; + } + + } + + /** + * Quantile loss is useful to predict rank/order and you do not mind the mean error to increase + * as long as you get the relative order correct. + * + * @link http://en.wikipedia.org/wiki/Quantile_regression + */ + public static final class QuantileLoss extends RegressionLoss { + + private float tau; + + public QuantileLoss() { + this.tau = 0.5f; + } + + public QuantileLoss(float tau) { + setTau(tau); + } + + public void setTau(float tau) { + if (tau <= 0 || tau >= 1.0) { + throw new IllegalArgumentException("tau must be in range (0, 1): " + tau); + } + this.tau = tau; + } + + @Override + public float loss(float p, float y) { + float e = y - p; + if (e > 0.f) { + return tau * e; + } else { + return -(1.f - tau) * e; + } + } + + @Override + public double loss(double p, double y) { + double e = y - p; + if (e > 0.d) { + return tau * e; + } else { + return -(1.d - tau) * e; + } + } + + @Override + public float dloss(float p, float y) { + float e = y - p; + if (e == 0.f) { + return 0.f; + } + return (e > 0.f) ? -tau : (1.f - tau); + } + + } + + /** + * Epsilon-Insensitive loss used by Support Vector Regression (SVR). + * <code>loss = max(0, |y - p| - epsilon)</code> + */ + public static final class EpsilonInsensitiveLoss extends RegressionLoss { + + private float epsilon; + + public EpsilonInsensitiveLoss() { + this(0.1f); + } + + public EpsilonInsensitiveLoss(float epsilon) { + this.epsilon = epsilon; + } + + public void setEpsilon(float epsilon) { + this.epsilon = epsilon; + } + + @Override + public float loss(float p, float y) { + float loss = Math.abs(y - p) - epsilon; + return (loss > 0.f) ? loss : 0.f; + } + + @Override + public double loss(double p, double y) { + double loss = Math.abs(y - p) - epsilon; + return (loss > 0.d) ? 
loss : 0.d; + } + + @Override + public float dloss(float p, float y) { + if ((y - p) > epsilon) {// real value > predicted value - epsilon + return -1.f; + } + if ((p - y) > epsilon) {// real value < predicted value - epsilon + return 1.f; + } + return 0.f; + } + + } + + public static float logisticLoss(final float target, final float predicted) { + if (predicted > -100.d) { + return target - (float) MathUtils.sigmoid(predicted); + } else { + return target; + } + } + + public static float logLoss(final float p, final float y) { + BinaryLoss.checkTarget(y); + + final float z = y * p; + if (z > 18.f) { + return (float) Math.exp(-z); + } + if (z < -18.f) { + return -z; + } + return (float) Math.log(1.d + Math.exp(-z)); + } + + public static double logLoss(final double p, final double y) { + BinaryLoss.checkTarget(y); + + final double z = y * p; + if (z > 18.d) { + return Math.exp(-z); + } + if (z < -18.d) { + return -z; + } + return Math.log(1.d + Math.exp(-z)); + } + + public static float squaredLoss(float p, float y) { + final float z = p - y; + return z * z * 0.5f; + } + + public static double squaredLoss(double p, double y) { + final double z = p - y; + return z * z * 0.5d; + } + + public static float hingeLoss(final float p, final float y, final float threshold) { + BinaryLoss.checkTarget(y); + + float z = y * p; + return threshold - z; + } + + public static double hingeLoss(final double p, final double y, final double threshold) { + BinaryLoss.checkTarget(y); + + double z = y * p; + return threshold - z; + } + + public static float hingeLoss(float p, float y) { + return hingeLoss(p, y, 1.f); + } + + public static double hingeLoss(double p, double y) { + return hingeLoss(p, y, 1.d); + } + + public static float squaredHingeLoss(final float p, final float y) { + BinaryLoss.checkTarget(y); + + float z = y * p; + float d = 1.f - z; + return (d > 0.f) ? (d * d) : 0.f; + } + + public static double squaredHingeLoss(final double p, final double y) { + BinaryLoss.checkTarget(y); + + double z = y * p; + double d = 1.d - z; + return (d > 0.d) ? d * d : 0.d; + } + + /** + * Math.abs(target - predicted) - epsilon + */ + public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) { + return Math.abs(target - predicted) - epsilon; + } +}
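A quick note for reviewers: the class above is self-contained, so it can be exercised from a scratch main(). The harness below is illustrative only (the LossFunctionExample class name and the sample values are invented here), but the LossFunctions/LossFunction calls are exactly the API this diff commits.

    import hivemall.optimizer.LossFunctions;
    import hivemall.optimizer.LossFunctions.LossFunction;
    import hivemall.optimizer.LossFunctions.LossType;

    public final class LossFunctionExample {

        public static void main(String[] args) {
            // Look up a loss by name; getLossFunction(String) matches case-insensitively
            LossFunction logloss = LossFunctions.getLossFunction("LogLoss");

            float p = 0.8f; // prediction, p = w^T x
            float y = -1.f; // BinaryLoss targets must be +1 or -1

            System.out.println(logloss.loss(p, y));  // log(1 + exp(-y*p)) ~= 1.17
            System.out.println(logloss.dloss(p, y)); // -y / (exp(y*p) + 1) ~= 0.69
            System.out.println(logloss.forBinaryClassification()); // true

            // The enum overload is equivalent; epsilon defaults to 0.1 for SVR loss
            LossFunction svr = LossFunctions.getLossFunction(LossType.EpsilonInsensitiveLoss);
            System.out.println(svr.loss(0.5f, 0.45f)); // 0.0: |y - p| is within epsilon
        }
    }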
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/Optimizer.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java new file mode 100644 index 0000000..863536c --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/Optimizer.java @@ -0,0 +1,246 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.optimizer; + +import java.util.Map; +import javax.annotation.Nonnull; +import javax.annotation.concurrent.NotThreadSafe; + +import hivemall.model.WeightValue; +import hivemall.model.IWeightValue; + +public interface Optimizer { + + /** + * Update the weights of models thru this interface. + */ + float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient); + + // Count up #step to tune learning rate + void proceedStep(); + + static abstract class OptimizerBase implements Optimizer { + + protected final EtaEstimator etaImpl; + protected final Regularization regImpl; + + protected int numStep = 1; + + public OptimizerBase(final Map<String, String> options) { + this.etaImpl = EtaEstimator.get(options); + this.regImpl = Regularization.get(options); + } + + @Override + public void proceedStep() { + numStep++; + } + + // Directly update a given `weight` in terms of performance + protected void computeUpdateValue( + @Nonnull final IWeightValue weight, float gradient) { + float delta = computeUpdateValueImpl(weight, regImpl.regularize(weight.get(), gradient)); + weight.set(weight.get() - etaImpl.eta(numStep) * delta); + } + + // Compute a delta to update + protected float computeUpdateValueImpl( + @Nonnull final IWeightValue weight, float gradient) { + return gradient; + } + + } + + @NotThreadSafe + static final class SGD extends OptimizerBase { + + private final IWeightValue weightValueReused; + + public SGD(final Map<String, String> options) { + super(options); + this.weightValueReused = new WeightValue(0.f); + } + + @Override + public float computeUpdatedValue( + @Nonnull Object feature, float weight, float gradient) { + computeUpdateValue(weightValueReused, gradient); + return weightValueReused.get(); + } + + } + + static abstract class AdaDelta extends OptimizerBase { + + private final float decay; + private final float eps; + private final float scale; + + public AdaDelta(Map<String, String> options) { + super(options); + float decay = 0.95f; + float eps = 1e-6f; + float scale = 100.0f; + if(options.containsKey("decay")) { + decay = Float.parseFloat(options.get("decay")); + } + if(options.containsKey("eps")) { + eps = Float.parseFloat(options.get("eps")); + } + if(options.containsKey("scale")) { + scale = Float.parseFloat(options.get("scale")); + } + 
this.decay = decay; + this.eps = eps; + this.scale = scale; + } + + @Override + protected float computeUpdateValueImpl(@Nonnull final IWeightValue weight, float gradient) { + float old_scaled_sum_sqgrad = weight.getSumOfSquaredGradients(); + float old_sum_squared_delta_x = weight.getSumOfSquaredDeltaX(); + float new_scaled_sum_sqgrad = (decay * old_scaled_sum_sqgrad) + ((1.f - decay) * gradient * (gradient / scale)); + float delta = (float) Math.sqrt((old_sum_squared_delta_x + eps) / (new_scaled_sum_sqgrad * scale + eps)) * gradient; + float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x) + ((1.f - decay) * delta * delta); + weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad); + weight.setSumOfSquaredDeltaX(new_sum_squared_delta_x); + return delta; + } + + } + + static abstract class AdaGrad extends OptimizerBase { + + private final float eps; + private final float scale; + + public AdaGrad(Map<String, String> options) { + super(options); + float eps = 1.0f; + float scale = 100.0f; + if(options.containsKey("eps")) { + eps = Float.parseFloat(options.get("eps")); + } + if(options.containsKey("scale")) { + scale = Float.parseFloat(options.get("scale")); + } + this.eps = eps; + this.scale = scale; + } + + @Override + protected float computeUpdateValueImpl(@Nonnull final IWeightValue weight, float gradient) { + float new_scaled_sum_sqgrad = weight.getSumOfSquaredGradients() + gradient * (gradient / scale); + float delta = gradient / ((float) Math.sqrt(new_scaled_sum_sqgrad * scale) + eps); + weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad); + return delta; + } + + } + + /** + * Adam, an algorithm for first-order gradient-based optimization of stochastic objective + * functions, based on adaptive estimates of lower-order moments. + * + * - D. P. Kingma and J. L. Ba: "ADAM: A Method for Stochastic Optimization." arXiv preprint arXiv:1412.6980v8, 2014. 
+ */ + static abstract class Adam extends OptimizerBase { + + private final float beta; + private final float gamma; + private final float eps_hat; + + public Adam(Map<String, String> options) { + super(options); + float beta = 0.9f; + float gamma = 0.999f; + float eps_hat = 1e-8f; + if(options.containsKey("beta")) { + beta = Float.parseFloat(options.get("beta")); + } + if(options.containsKey("gamma")) { + gamma = Float.parseFloat(options.get("gamma")); + } + if(options.containsKey("eps_hat")) { + eps_hat = Float.parseFloat(options.get("eps_hat")); + } + this.beta = beta; + this.gamma = gamma; + this.eps_hat = eps_hat; + } + + @Override + protected float computeUpdateValueImpl(@Nonnull final IWeightValue weight, float gradient) { + float val_m = beta * weight.getM() + (1.f - beta) * gradient; + float val_v = gamma * weight.getV() + (float) ((1.f - gamma) * Math.pow(gradient, 2.0)); + float val_m_hat = val_m / (float) (1.f - Math.pow(beta, numStep)); + float val_v_hat = val_v / (float) (1.f - Math.pow(gamma, numStep)); + float delta = val_m_hat / (float) (Math.sqrt(val_v_hat) + eps_hat); + weight.setM(val_m); + weight.setV(val_v); + return delta; + } + + } + + static abstract class RDA extends OptimizerBase { + + private final OptimizerBase optimizerImpl; + + private final float lambda; + + public RDA(final OptimizerBase optimizerImpl, Map<String, String> options) { + super(options); + // We assume `optimizerImpl` has the `AdaGrad` implementation only + if(!(optimizerImpl instanceof AdaGrad)) { + throw new IllegalArgumentException( + optimizerImpl.getClass().getSimpleName() + + " currently does not support RDA regularization"); + } + float lambda = 1e-6f; + if(options.containsKey("lambda")) { + lambda = Float.parseFloat(options.get("lambda")); + } + this.optimizerImpl = optimizerImpl; + this.lambda = lambda; + } + + @Override + protected void computeUpdateValue(@Nonnull final IWeightValue weight, float gradient) { + float new_sum_grad = weight.getSumOfGradients() + gradient; + // sign(u_{t,i}) + float sign = (new_sum_grad > 0.f)? 1.f : -1.f; + // |u_{t,i}|/t - \lambda + float meansOfGradients = (sign * new_sum_grad / numStep) - lambda; + if(meansOfGradients < 0.f) { + // x_{t,i} = 0 + weight.set(0.f); + weight.setSumOfSquaredGradients(0.f); + weight.setSumOfGradients(0.f); + } else { + // x_{t,i} = -sign(u_{t,i}) * \frac{\eta t}{\sqrt{G_{t,ii}}}(|u_{t,i}|/t - \lambda) + float new_weight = -1.f * sign * etaImpl.eta(numStep) * numStep * optimizerImpl.computeUpdateValueImpl(weight, meansOfGradients); + weight.set(new_weight); + weight.setSumOfGradients(new_sum_grad); + } + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/Regularization.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/Regularization.java b/core/src/main/java/hivemall/optimizer/Regularization.java new file mode 100644 index 0000000..ce1ef7f --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/Regularization.java @@ -0,0 +1,99 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.optimizer; + +import javax.annotation.Nonnull; +import java.util.Map; + +public abstract class Regularization { + + protected final float lambda; + + public Regularization(final Map<String, String> options) { + float lambda = 1e-6f; + if(options.containsKey("lambda")) { + lambda = Float.parseFloat(options.get("lambda")); + } + this.lambda = lambda; + } + + abstract float regularize(float weight, float gradient); + + public static final class PassThrough extends Regularization { + + public PassThrough(final Map<String, String> options) { + super(options); + } + + @Override + public float regularize(float weight, float gradient) { + return gradient; + } + + } + + public static final class L1 extends Regularization { + + public L1(Map<String, String> options) { + super(options); + } + + @Override + public float regularize(float weight, float gradient) { + return gradient + lambda * (weight > 0.f? 1.f : -1.f); + } + + } + + public static final class L2 extends Regularization { + + public L2(final Map<String, String> options) { + super(options); + } + + @Override + public float regularize(float weight, float gradient) { + return gradient + lambda * weight; + } + + } + + @Nonnull + public static Regularization get(@Nonnull final Map<String, String> options) + throws IllegalArgumentException { + final String regName = options.get("regularization"); + if (regName == null) { + return new PassThrough(options); + } + if(regName.toLowerCase().equals("no")) { + return new PassThrough(options); + } else if(regName.toLowerCase().equals("l1")) { + return new L1(options); + } else if(regName.toLowerCase().equals("l2")) { + return new L2(options); + } else if(regName.toLowerCase().equals("rda")) { + // Return `PassThrough` because we need special handling for RDA. + // See an implementation of `Optimizer#RDA`. + return new PassThrough(options); + } else { + throw new IllegalArgumentException("Unsupported regularization name: " + regName); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java new file mode 100644 index 0000000..a74d0da --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java @@ -0,0 +1,171 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.optimizer; + +import javax.annotation.Nonnull; +import javax.annotation.concurrent.NotThreadSafe; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import hivemall.optimizer.Optimizer.OptimizerBase; +import hivemall.model.IWeightValue; +import hivemall.model.WeightValue; +import hivemall.utils.collections.OpenHashMap; + +public final class SparseOptimizerFactory { + private static final Log logger = LogFactory.getLog(SparseOptimizerFactory.class); + + @Nonnull + public static Optimizer create(int ndims, @Nonnull Map<String, String> options) { + final String optimizerName = options.get("optimizer"); + if(optimizerName != null) { + OptimizerBase optimizerImpl; + if(optimizerName.toLowerCase().equals("sgd")) { + optimizerImpl = new Optimizer.SGD(options); + } else if(optimizerName.toLowerCase().equals("adadelta")) { + optimizerImpl = new AdaDelta(ndims, options); + } else if(optimizerName.toLowerCase().equals("adagrad")) { + optimizerImpl = new AdaGrad(ndims, options); + } else if(optimizerName.toLowerCase().equals("adam")) { + optimizerImpl = new Adam(ndims, options); + } else { + throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName); + } + + logger.info("set " + optimizerImpl.getClass().getSimpleName() + + " as an optimizer: " + options); + + // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`. 
+ if(options.get("regularization") != null + && options.get("regularization").toLowerCase().equals("rda")) { + optimizerImpl = new RDA(ndims, optimizerImpl, options); + } + + return optimizerImpl; + } + throw new IllegalArgumentException("`optimizer` not defined"); + } + + @NotThreadSafe + static final class AdaDelta extends Optimizer.AdaDelta { + + private final OpenHashMap<Object, IWeightValue> auxWeights; + + public AdaDelta(int size, Map<String, String> options) { + super(options); + this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + } + + @Override + public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) { + IWeightValue auxWeight; + if(auxWeights.containsKey(feature)) { + auxWeight = auxWeights.get(feature); + auxWeight.set(weight); + } else { + auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeights.put(feature, auxWeight); + } + computeUpdateValue(auxWeight, gradient); + return auxWeight.get(); + } + + } + + @NotThreadSafe + static final class AdaGrad extends Optimizer.AdaGrad { + + private final OpenHashMap<Object, IWeightValue> auxWeights; + + public AdaGrad(int size, Map<String, String> options) { + super(options); + this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + } + + @Override + public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) { + IWeightValue auxWeight; + if(auxWeights.containsKey(feature)) { + auxWeight = auxWeights.get(feature); + auxWeight.set(weight); + } else { + auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeights.put(feature, auxWeight); + } + computeUpdateValue(auxWeight, gradient); + return auxWeight.get(); + } + + } + + @NotThreadSafe + static final class Adam extends Optimizer.Adam { + + private final OpenHashMap<Object, IWeightValue> auxWeights; + + public Adam(int size, Map<String, String> options) { + super(options); + this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + } + + @Override + public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) { + IWeightValue auxWeight; + if(auxWeights.containsKey(feature)) { + auxWeight = auxWeights.get(feature); + auxWeight.set(weight); + } else { + auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeights.put(feature, auxWeight); + } + computeUpdateValue(auxWeight, gradient); + return auxWeight.get(); + } + + } + + @NotThreadSafe + static final class RDA extends Optimizer.RDA { + + private final OpenHashMap<Object, IWeightValue> auxWeights; + + public RDA(int size, OptimizerBase optimizerImpl, Map<String, String> options) { + super(optimizerImpl, options); + this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + } + + @Override + public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) { + IWeightValue auxWeight; + if(auxWeights.containsKey(feature)) { + auxWeight = auxWeights.get(feature); + auxWeight.set(weight); + } else { + auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeights.put(feature, auxWeight); + } + computeUpdateValue(auxWeight, gradient); + return auxWeight.get(); + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java index 
b81a4bf..0c964c8 100644 --- a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java +++ b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java @@ -18,7 +18,7 @@ */ package hivemall.regression; -import hivemall.common.LossFunctions; +import hivemall.optimizer.LossFunctions; import hivemall.common.OnlineVariance; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java index e807340..50dc9b5 100644 --- a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java +++ b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java @@ -18,123 +18,14 @@ */ package hivemall.regression; -import hivemall.common.LossFunctions; -import hivemall.model.FeatureValue; -import hivemall.model.IWeightValue; -import hivemall.model.WeightValue.WeightValueParamsF2; -import hivemall.utils.lang.Primitives; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import org.apache.commons.cli.CommandLine; -import org.apache.commons.cli.Options; -import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; - /** * ADADELTA: AN ADAPTIVE LEARNING RATE METHOD. */ -@Description( - name = "train_adadelta_regr", - value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])" - + " - Returns a relation consists of <{int|bigint|string} feature, float weight>") -public final class AdaDeltaUDTF extends RegressionBaseUDTF { - - private float decay; - private float eps; - private float scaling; - - @Override - public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - final int numArgs = argOIs.length; - if (numArgs != 2 && numArgs != 3) { - throw new UDFArgumentException( - "AdaDeltaUDTF takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]"); - } - - StructObjectInspector oi = super.initialize(argOIs); - model.configureParams(true, true, false); - return oi; - } - - @Override - protected Options getOptions() { - Options opts = super.getOptions(); - opts.addOption("rho", "decay", true, "Decay rate [default 0.95]"); - opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1e-6]"); - opts.addOption("scale", true, - "Internal scaling/descaling factor for cumulative weights [100]"); - return opts; - } - - @Override - protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { - CommandLine cl = super.processOptions(argOIs); - if (cl == null) { - this.decay = 0.95f; - this.eps = 1e-6f; - this.scaling = 100f; - } else { - this.decay = Primitives.parseFloat(cl.getOptionValue("decay"), 0.95f); - this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1E-6f); - this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f); - } - return cl; - } - - @Override - protected final void checkTargetValue(final float target) throws UDFArgumentException { - if (target < 0.f || target > 1.f) { - throw new UDFArgumentException("target must be in range 0 to 1: " + target); - } - } - - @Override - 
protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) { - float gradient = LossFunctions.logisticLoss(target, predicted); - onlineUpdate(features, gradient); - } - - @Override - protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) { - final float g_g = gradient * (gradient / scaling); - - for (FeatureValue f : features) {// w[i] += y * x[i] - if (f == null) { - continue; - } - Object x = f.getFeature(); - float xi = f.getValueAsFloat(); - - IWeightValue old_w = model.get(x); - IWeightValue new_w = getNewWeight(old_w, xi, gradient, g_g); - model.set(x, new_w); - } - } - - @Nonnull - protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi, - final float gradient, final float g_g) { - float old_w = 0.f; - float old_scaled_sum_sqgrad = 0.f; - float old_sum_squared_delta_x = 0.f; - if (old != null) { - old_w = old.get(); - old_scaled_sum_sqgrad = old.getSumOfSquaredGradients(); - old_sum_squared_delta_x = old.getSumOfSquaredDeltaX(); - } +@Deprecated +public final class AdaDeltaUDTF extends GeneralRegressionUDTF { - float new_scaled_sum_sq_grad = (decay * old_scaled_sum_sqgrad) + ((1.f - decay) * g_g); - float dx = (float) Math.sqrt((old_sum_squared_delta_x + eps) - / (old_scaled_sum_sqgrad * scaling + eps)) - * gradient; - float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x) - + ((1.f - decay) * dx * dx); - float new_w = old_w + (dx * xi); - return new WeightValueParamsF2(new_w, new_scaled_sum_sq_grad, new_sum_squared_delta_x); + public AdaDeltaUDTF() { + optimizerOptions.put("optimizer", "AdaDelta"); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/AdaGradUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/AdaGradUDTF.java b/core/src/main/java/hivemall/regression/AdaGradUDTF.java index de48d97..4b5f019 100644 --- a/core/src/main/java/hivemall/regression/AdaGradUDTF.java +++ b/core/src/main/java/hivemall/regression/AdaGradUDTF.java @@ -18,124 +18,14 @@ */ package hivemall.regression; -import hivemall.common.LossFunctions; -import hivemall.model.FeatureValue; -import hivemall.model.IWeightValue; -import hivemall.model.WeightValue.WeightValueParamsF1; -import hivemall.utils.lang.Primitives; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import org.apache.commons.cli.CommandLine; -import org.apache.commons.cli.Options; -import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; - /** * ADAGRAD algorithm with element-wise adaptive learning rates. 
*/ -@Description( - name = "train_adagrad_regr", - value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])" - + " - Returns a relation consists of <{int|bigint|string} feature, float weight>") -public final class AdaGradUDTF extends RegressionBaseUDTF { - - private float eta; - private float eps; - private float scaling; - - @Override - public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - final int numArgs = argOIs.length; - if (numArgs != 2 && numArgs != 3) { - throw new UDFArgumentException( - "_FUNC_ takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]"); - } - - StructObjectInspector oi = super.initialize(argOIs); - model.configureParams(true, false, false); - return oi; - } - - @Override - protected Options getOptions() { - Options opts = super.getOptions(); - opts.addOption("eta", "eta0", true, "The initial learning rate [default 1.0]"); - opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]"); - opts.addOption("scale", true, - "Internal scaling/descaling factor for cumulative weights [100]"); - return opts; - } - - @Override - protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { - CommandLine cl = super.processOptions(argOIs); - if (cl == null) { - this.eta = 1.f; - this.eps = 1.f; - this.scaling = 100f; - } else { - this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f); - this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1.f); - this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f); - } - return cl; - } - - @Override - protected final void checkTargetValue(final float target) throws UDFArgumentException { - if (target < 0.f || target > 1.f) { - throw new UDFArgumentException("target must be in range 0 to 1: " + target); - } - } - - @Override - protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) { - float gradient = LossFunctions.logisticLoss(target, predicted); - onlineUpdate(features, gradient); - } - - @Override - protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) { - final float g_g = gradient * (gradient / scaling); - - for (FeatureValue f : features) {// w[i] += y * x[i] - if (f == null) { - continue; - } - Object x = f.getFeature(); - float xi = f.getValueAsFloat(); - - IWeightValue old_w = model.get(x); - IWeightValue new_w = getNewWeight(old_w, xi, gradient, g_g); - model.set(x, new_w); - } - } - - @Nonnull - protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi, - final float gradient, final float g_g) { - float old_w = 0.f; - float scaled_sum_sqgrad = 0.f; - - if (old != null) { - old_w = old.get(); - scaled_sum_sqgrad = old.getSumOfSquaredGradients(); - } - scaled_sum_sqgrad += g_g; - - float coeff = eta(scaled_sum_sqgrad) * gradient; - float new_w = old_w + (coeff * xi); - return new WeightValueParamsF1(new_w, scaled_sum_sqgrad); - } +@Deprecated +public final class AdaGradUDTF extends GeneralRegressionUDTF { - protected float eta(final double scaledSumOfSquaredGradients) { - double sumOfSquaredGradients = scaledSumOfSquaredGradients * scaling; - //return eta / (float) Math.sqrt(sumOfSquaredGradients); - return eta / (float) Math.sqrt(eps + sumOfSquaredGradients); // always less than eta0 + public AdaGradUDTF() { + optimizerOptions.put("optimizer", "AdaGrad"); } } 
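Before the GeneralRegressionUDTF below, it is worth spelling out the arithmetic that Optimizer.AdaGrad performs in the diff above, since the scale option is easy to misread: the accumulator stores sum(g^2)/scale rather than sum(g^2), so the running float sum stays in a safe range, and the sqrt re-scales it back. A standalone sketch, with the class name and toy gradient invented for illustration:

    // Re-statement of Optimizer.AdaGrad#computeUpdateValueImpl with the default
    // eps=1.0 and scale=100.0; only the hand-rolled harness around it is new.
    public final class ScaledAdaGradSketch {

        private static final float EPS = 1.0f;
        private static final float SCALE = 100.0f;

        // Plays the role of IWeightValue#getSumOfSquaredGradients()
        private float scaledSumSqGrad = 0.f;

        float delta(final float gradient) {
            // Accumulate g^2 / scale to keep the float sum small
            scaledSumSqGrad += gradient * (gradient / SCALE);
            // De-scale inside the sqrt: sqrt(scaledSum * scale) == sqrt(sum(g^2))
            return gradient / ((float) Math.sqrt(scaledSumSqGrad * SCALE) + EPS);
        }

        public static void main(String[] args) {
            ScaledAdaGradSketch adagrad = new ScaledAdaGradSketch();
            float w = 0.f;
            for (int step = 1; step <= 3; step++) {
                // OptimizerBase then applies w -= eta(step) * delta; eta is taken as 1.0 here
                w -= adagrad.delta(0.1f);
            }
            System.out.println("w = " + w);
        }
    }

AdaDelta follows the same de-scaling trick, with an additional decayed accumulator of squared deltas in the numerator of the sqrt.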
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java new file mode 100644 index 0000000..2a8b543 --- /dev/null +++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java @@ -0,0 +1,125 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.regression; + +import java.util.HashMap; +import java.util.Map; +import javax.annotation.Nonnull; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; + +import hivemall.optimizer.LossFunctions; +import hivemall.model.FeatureValue; + +/** + * A general regression class with replaceable optimization functions. 
+ */ +public class GeneralRegressionUDTF extends RegressionBaseUDTF { + + protected final Map<String, String> optimizerOptions; + + public GeneralRegressionUDTF() { + this.optimizerOptions = new HashMap<String, String>(); + // Set default values + optimizerOptions.put("optimizer", "adadelta"); + optimizerOptions.put("eta", "fixed"); + optimizerOptions.put("eta0", "1.0"); + optimizerOptions.put("t", "10000"); + optimizerOptions.put("power_t", "0.1"); + optimizerOptions.put("eps", "1e-6"); + optimizerOptions.put("rho", "0.95"); + optimizerOptions.put("scale", "100.0"); + optimizerOptions.put("lambda", "1.0"); + } + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if(argOIs.length != 2 && argOIs.length != 3) { + throw new UDFArgumentException( + this.getClass().getSimpleName() + + " takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target " + + "[, constant string options]"); + } + return super.initialize(argOIs); + } + + @Override + protected Options getOptions() { + Options opts = super.getOptions(); + opts.addOption("optimizer", "opt", true, "Optimizer to update weights [default: adadelta]"); + opts.addOption("eta", true, " ETA estimator to compute delta [default: fixed]"); + opts.addOption("eta0", true, "Initial learning rate [default 1.0]"); + opts.addOption("t", "total_steps", true, "Total of n_samples * epochs time steps [default: 10000]"); + opts.addOption("power_t", true, "Exponent for inverse scaling learning rate [default 0.1]"); + opts.addOption("eps", true, "Denominator value of AdaDelta/AdaGrad [default 1e-6]"); + opts.addOption("rho", "decay", true, "Decay rate [default 0.95]"); + opts.addOption("scale", true, "Scaling factor for cumulative weights [100.0]"); + opts.addOption("regularization", "reg", true, "Regularization type [default not-defined]"); + opts.addOption("lambda", true, "Regularization term on weights [default 1.0]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + final CommandLine cl = super.processOptions(argOIs); + if(cl != null) { + for(final Option opt: cl.getOptions()) { + optimizerOptions.put(opt.getOpt(), opt.getValue()); + } + } + return cl; + } + + @Override + protected Map<String, String> getOptimzierOptions() { + return optimizerOptions; + } + + @Override + protected final void checkTargetValue(final float target) throws UDFArgumentException { + if(target < 0.f || target > 1.f) { + throw new UDFArgumentException("target must be in range 0 to 1: " + target); + } + } + + @Override + protected void update(@Nonnull final FeatureValue[] features, final float target, + final float predicted) { + if(is_mini_batch) { + throw new UnsupportedOperationException( + this.getClass().getSimpleName() + " supports no `is_mini_batch` mode"); + } else { + float loss = LossFunctions.logisticLoss(target, predicted); + for(FeatureValue f : features) { + Object feature = f.getFeature(); + float xi = f.getValueAsFloat(); + float weight = model.getWeight(feature); + float new_weight = optimizerImpl.computeUpdatedValue(feature, weight, -loss * xi); + model.setWeight(feature, new_weight); + } + optimizerImpl.proceedStep(); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/LogressUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/LogressUDTF.java 
b/core/src/main/java/hivemall/regression/LogressUDTF.java index ca3da71..ea05da3 100644 --- a/core/src/main/java/hivemall/regression/LogressUDTF.java +++ b/core/src/main/java/hivemall/regression/LogressUDTF.java @@ -18,65 +18,12 @@ */ package hivemall.regression; -import hivemall.common.EtaEstimator; -import hivemall.common.LossFunctions; +@Deprecated +public final class LogressUDTF extends GeneralRegressionUDTF { -import org.apache.commons.cli.CommandLine; -import org.apache.commons.cli.Options; -import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; - -@Description( - name = "logress", - value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])" - + " - Returns a relation consists of <{int|bigint|string} feature, float weight>") -public final class LogressUDTF extends RegressionBaseUDTF { - - private EtaEstimator etaEstimator; - - @Override - public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - final int numArgs = argOIs.length; - if (numArgs != 2 && numArgs != 3) { - throw new UDFArgumentException( - "LogressUDTF takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]"); - } - - return super.initialize(argOIs); - } - - @Override - protected Options getOptions() { - Options opts = super.getOptions(); - opts.addOption("t", "total_steps", true, "a total of n_samples * epochs time steps"); - opts.addOption("power_t", true, - "The exponent for inverse scaling learning rate [default 0.1]"); - opts.addOption("eta0", true, "The initial learning rate [default 0.1]"); - return opts; - } - - @Override - protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { - CommandLine cl = super.processOptions(argOIs); - - this.etaEstimator = EtaEstimator.get(cl); - return cl; - } - - @Override - protected void checkTargetValue(final float target) throws UDFArgumentException { - if (target < 0.f || target > 1.f) { - throw new UDFArgumentException("target must be in range 0 to 1: " + target); - } - } - - @Override - protected float computeUpdate(final float target, final float predicted) { - float eta = etaEstimator.eta(count); - float gradient = LossFunctions.logisticLoss(target, predicted); - return eta * gradient; + public LogressUDTF() { + optimizerOptions.put("optimizer", "SGD"); + optimizerOptions.put("eta", "fixed"); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java index c089946..e1afe2f 100644 --- a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java +++ b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java @@ -18,7 +18,7 @@ */ package hivemall.regression; -import hivemall.common.LossFunctions; +import hivemall.optimizer.LossFunctions; import hivemall.common.OnlineVariance; import hivemall.model.FeatureValue; import hivemall.model.PredictionResult; 
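Putting the pieces together, the control flow that GeneralRegressionUDTF#update drives through the new Optimizer interface looks roughly like the sketch below. The HashMap stands in for the real PredictionModel, and the feature values and class name are invented; the factory call, logisticLoss, computeUpdatedValue, and proceedStep match the code committed above.

    import java.util.HashMap;
    import java.util.Map;

    import hivemall.optimizer.LossFunctions;
    import hivemall.optimizer.Optimizer;
    import hivemall.optimizer.SparseOptimizerFactory;

    public final class OptimizerLoopSketch {

        public static void main(String[] args) {
            final Map<String, String> options = new HashMap<String, String>();
            options.put("optimizer", "AdaGrad"); // matched case-insensitively by the factory
            options.put("regularization", "L2");
            final Optimizer optimizer = SparseOptimizerFactory.create(1024, options);

            // Toy stand-in for PredictionModel: feature -> weight
            final Map<Object, Float> model = new HashMap<Object, Float>();

            final Object[] features = {"f1", "f2"};
            final float[] values = {1.f, 0.5f};
            final float target = 1.f; // the regression UDTF checks targets are in [0, 1]

            final float predicted = 0.f; // w^T x with all-zero initial weights
            // logisticLoss returns target - sigmoid(predicted), i.e. the negative gradient
            final float loss = LossFunctions.logisticLoss(target, predicted);
            for (int i = 0; i < features.length; i++) {
                Float old = model.get(features[i]);
                float weight = (old == null) ? 0.f : old.floatValue();
                float newWeight = optimizer.computeUpdatedValue(features[i], weight, -loss * values[i]);
                model.put(features[i], newWeight);
            }
            optimizer.proceedStep(); // advances the step counter behind eta(numStep)
            System.out.println(model);
        }
    }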
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java index 561d4f7..7dc8538 100644 --- a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java +++ b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java @@ -25,6 +25,7 @@ import hivemall.model.PredictionModel; import hivemall.model.PredictionResult; import hivemall.model.WeightValue; import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.optimizer.Optimizer; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.FloatAccumulator; @@ -64,6 +65,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { private boolean parseFeature; protected PredictionModel model; + protected Optimizer optimizerImpl; protected int count; // The accumulated delta of each weight values. @@ -87,6 +89,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { if (preloadedModelFile != null) { loadPredictionModel(model, preloadedModelFile, featureOutputOI); } + this.optimizerImpl = createOptimizer(); this.count = 0; this.sampled = 0; @@ -235,7 +238,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { protected void update(@Nonnull final FeatureValue[] features, final float target, final float predicted) { - final float grad = computeUpdate(target, predicted); + final float grad = computeGradient(target, predicted); if (is_mini_batch) { accumulateUpdate(features, grad); @@ -247,12 +250,9 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { } } - protected float computeUpdate(float target, float predicted) { - throw new IllegalStateException(); - } - - protected IWeightValue getNewWeight(IWeightValue old_w, float delta) { - throw new IllegalStateException(); + // Compute a gradient by using a loss function in derived classes + protected float computeGradient(float target, float predicted) { + throw new UnsupportedOperationException(); } protected final void accumulateUpdate(@Nonnull final FeatureValue[] features, final float coeff) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/test/java/hivemall/optimizer/OptimizerTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/optimizer/OptimizerTest.java b/core/src/test/java/hivemall/optimizer/OptimizerTest.java new file mode 100644 index 0000000..cfcfa79 --- /dev/null +++ b/core/src/test/java/hivemall/optimizer/OptimizerTest.java @@ -0,0 +1,172 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.optimizer; + +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +import org.junit.Assert; +import org.junit.Test; + +public final class OptimizerTest { + + @Test + public void testIllegalOptimizer() { + try { + final Map<String, String> emptyOptions = new HashMap<String, String>(); + DenseOptimizerFactory.create(1024, emptyOptions); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + try { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "illegal"); + DenseOptimizerFactory.create(1024, options); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + try { + final Map<String, String> emptyOptions = new HashMap<String, String>(); + SparseOptimizerFactory.create(1024, emptyOptions); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + try { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "illegal"); + SparseOptimizerFactory.create(1024, options); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + } + + @Test + public void testOptimizerFactory() { + final Map<String, String> options = new HashMap<String, String>(); + final String[] regTypes = new String[] {"NO", "L1", "L2"}; + for(final String regType : regTypes) { + options.put("optimizer", "SGD"); + options.put("regularization", regType); + Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof Optimizer.SGD); + Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof Optimizer.SGD); + } + for(final String regType : regTypes) { + options.put("optimizer", "AdaDelta"); + options.put("regularization", regType); + Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.AdaDelta); + Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.AdaDelta); + } + for(final String regType : regTypes) { + options.put("optimizer", "AdaGrad"); + options.put("regularization", regType); + Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.AdaGrad); + Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.AdaGrad); + } + for(final String regType : regTypes) { + options.put("optimizer", "Adam"); + options.put("regularization", regType); + Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.Adam); + Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.Adam); + } + + // We need special handling for `Optimizer#RDA` + options.put("optimizer", "AdaGrad"); + options.put("regularization", "RDA"); + Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.RDA); + Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.RDA); + + // `SGD`, `AdaDelta`, and `Adam` currently does not support `RDA` + for(final String optimizerType : new String[] {"SGD", "AdaDelta", "Adam"}) { + options.put("optimizer", optimizerType); + try { + DenseOptimizerFactory.create(8, options); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + try { + SparseOptimizerFactory.create(8, options); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + } + 
} + + private void testUpdateWeights(Optimizer optimizer, int numUpdates, int initSize) { + final float[] weights = new float[initSize * 2]; + final Random rnd = new Random(); + try { + for(int i = 0; i < numUpdates; i++) { + int index = rnd.nextInt(initSize); + weights[index] = optimizer.computeUpdatedValue(index, weights[index], 0.1f); + } + for(int i = 0; i < numUpdates; i++) { + int index = rnd.nextInt(initSize * 2); + weights[index] = optimizer.computeUpdatedValue(index, weights[index], 0.1f); + } + } catch(Exception e) { + Assert.fail("failed to update weights: " + e.getMessage()); + } + } + + private void testOptimizer(final Map<String, String> options, int numUpdates, int initSize) { + final Map<String, String> testOptions = new HashMap<String, String>(options); + // `RDA` requires `AdaGrad` (see `testOptimizerFactory` above), so test it for AdaGrad only + final String[] regTypes = "AdaGrad".equalsIgnoreCase(options.get("optimizer")) ? new String[] {"NO", "L1", "L2", "RDA"} : new String[] {"NO", "L1", "L2"}; + for(final String regType : regTypes) { + testOptions.put("regularization", regType); + testUpdateWeights(DenseOptimizerFactory.create(initSize, testOptions), numUpdates, initSize); + testUpdateWeights(SparseOptimizerFactory.create(initSize, testOptions), numUpdates, initSize); + } + } + + @Test + public void testSGDOptimizer() { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "SGD"); + testOptimizer(options, 65536, 1024); + } + + @Test + public void testAdaDeltaOptimizer() { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "AdaDelta"); + testOptimizer(options, 65536, 1024); + } + + @Test + public void testAdaGradOptimizer() { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "AdaGrad"); + testOptimizer(options, 65536, 1024); + } + + @Test + public void testAdamOptimizer() { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "Adam"); + testOptimizer(options, 65536, 1024); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java ---------------------------------------------------------------------- diff --git a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java index 0b1455c..38792d8 100644 --- a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java +++ b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java @@ -55,7 +55,7 @@ public class MixServerTest extends HivemallTestBase { waitForState(server, ServerState.RUNNING); - PredictionModel model = new DenseModel(16777216, false); + PredictionModel model = new DenseModel(16777216); model.configureClock(); MixClient client = null; try { @@ -93,7 +93,7 @@ public class MixServerTest extends HivemallTestBase { waitForState(server, ServerState.RUNNING); - PredictionModel model = new DenseModel(16777216, false); + PredictionModel model = new DenseModel(16777216); model.configureClock(); MixClient client = null; try { @@ -151,7 +151,7 @@ public class MixServerTest extends HivemallTestBase { } private static void invokeClient(String groupId, int serverPort) throws InterruptedException { - PredictionModel model = new DenseModel(16777216, false); + PredictionModel model = new DenseModel(16777216); model.configureClock(); MixClient client = null; try { @@ -296,10 +296,10 @@ public class MixServerTest extends HivemallTestBase { serverExec.shutdown(); } - private static void invokeClient01(String groupId, int serverPort, boolean denseModel, - boolean cancelMix) throws InterruptedException { -
PredictionModel model = denseModel ? new DenseModel(100, false) : new SparseModel(100, - false); + private static void invokeClient01(String groupId, int serverPort, boolean denseModel, boolean cancelMix) + throws InterruptedException { + PredictionModel model = denseModel ? new DenseModel(100) + : new SparseModel(100, false); model.configureClock(); MixClient client = null; try { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/resources/ddl/define-all-as-permanent.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index bab5a29..ccdace0 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -13,6 +13,9 @@ CREATE FUNCTION hivemall_version as 'hivemall.HivemallVersionUDF' USING JAR '${h -- binary classification -- --------------------------- +DROP FUNCTION IF EXISTS train_classifier; +CREATE FUNCTION train_classifier as 'hivemall.classifier.GeneralClassifierUDTF' USING JAR '${hivemall_jar}'; + DROP FUNCTION IF EXISTS train_perceptron; CREATE FUNCTION train_perceptron as 'hivemall.classifier.PerceptronUDTF' USING JAR '${hivemall_jar}'; @@ -45,7 +48,7 @@ CREATE FUNCTION train_adagrad_rda as 'hivemall.classifier.AdaGradRDAUDTF' USING -------------------------------- -- Multiclass classification -- --------------------------------- +-------------------------------- DROP FUNCTION IF EXISTS train_multiclass_perceptron; CREATE FUNCTION train_multiclass_perceptron as 'hivemall.classifier.multiclass.MulticlassPerceptronUDTF' USING JAR '${hivemall_jar}'; @@ -312,6 +315,13 @@ CREATE FUNCTION tf as 'hivemall.ftvec.text.TermFrequencyUDAF' USING JAR '${hivem -- Regression functions -- -------------------------- +DROP FUNCTION IF EXISTS train_regression; +CREATE FUNCTION train_regression as 'hivemall.regression.GeneralRegressionUDTF' USING JAR '${hivemall_jar}'; + +DROP FUNCTION IF EXISTS train_logregr; +CREATE FUNCTION train_logregr as 'hivemall.regression.LogressUDTF' USING JAR '${hivemall_jar}'; + +-- alias for backward compatibility DROP FUNCTION IF EXISTS logress; CREATE FUNCTION logress as 'hivemall.regression.LogressUDTF' USING JAR '${hivemall_jar}'; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/resources/ddl/define-all.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index 315b4d2..d60fd7f 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -9,6 +9,9 @@ create temporary function hivemall_version as 'hivemall.HivemallVersionUDF'; -- binary classification -- --------------------------- +drop temporary function train_classifier; +create temporary function train_classifier as 'hivemall.classifier.GeneralClassifierUDTF'; + drop temporary function train_perceptron; create temporary function train_perceptron as 'hivemall.classifier.PerceptronUDTF'; @@ -308,6 +311,13 @@ create temporary function tf as 'hivemall.ftvec.text.TermFrequencyUDAF'; -- Regression functions -- -------------------------- +drop temporary function train_regression; +create temporary function
train_regression as 'hivemall.regression.GeneralRegressionUDTF'; + +drop temporary function train_logregr; +create temporary function train_logregr as 'hivemall.regression.LogressUDTF'; + +-- alias for backward compatibility drop temporary function logress; create temporary function logress as 'hivemall.regression.LogressUDTF'; @@ -628,5 +638,3 @@ log(10, n_docs / max2(1,df_t)) + 1.0; create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE) tf * (log(10, n_docs / max2(1,df_t)) + 1.0); - -
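With the DDL above in place, the generic trainers become callable from HiveQL. A hypothetical query follows (the training table and its features/label columns are illustrative, and so is the chosen option string; the flag names come from GeneralRegressionUDTF#getOptions, and the UDTF emits a <feature, weight> relation that is averaged per feature in the usual Hivemall pattern):

    SELECT
      feature,
      avg(weight) AS weight
    FROM (
      SELECT
        train_regression(features, label,
          '-optimizer AdaGrad -regularization L2 -lambda 1e-6 -eta0 1.0')
          AS (feature, weight)
      FROM training
    ) t
    GROUP BY feature;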