Revert some modifications
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/3620eb89 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/3620eb89 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/3620eb89 Branch: refs/heads/JIRA-22/pr-285 Commit: 3620eb89993db22ce8aee924d3cc0df33a5f9618 Parents: f81948c Author: Takeshi YAMAMURO <[email protected]> Authored: Wed Sep 21 01:52:22 2016 +0900 Committer: Takeshi YAMAMURO <[email protected]> Committed: Wed Sep 21 01:55:59 2016 +0900 ---------------------------------------------------------------------- .../src/main/java/hivemall/LearnerBaseUDTF.java | 33 ++ .../hivemall/classifier/AROWClassifierUDTF.java | 2 +- .../hivemall/classifier/AdaGradRDAUDTF.java | 125 +++++++- .../classifier/BinaryOnlineClassifierUDTF.java | 10 + .../classifier/GeneralClassifierUDTF.java | 1 + .../classifier/PassiveAggressiveUDTF.java | 2 +- .../main/java/hivemall/model/DenseModel.java | 86 ++++- .../main/java/hivemall/model/NewDenseModel.java | 293 +++++++++++++++++ .../model/NewSpaceEfficientDenseModel.java | 317 +++++++++++++++++++ .../java/hivemall/model/NewSparseModel.java | 197 ++++++++++++ .../java/hivemall/model/PredictionModel.java | 3 + .../model/SpaceEfficientDenseModel.java | 92 +++++- .../main/java/hivemall/model/SparseModel.java | 19 +- .../model/SynchronizedModelWrapper.java | 6 + .../hivemall/regression/AROWRegressionUDTF.java | 2 +- .../java/hivemall/regression/AdaDeltaUDTF.java | 118 ++++++- .../java/hivemall/regression/AdaGradUDTF.java | 119 ++++++- .../regression/GeneralRegressionUDTF.java | 1 + .../java/hivemall/regression/LogressUDTF.java | 65 +++- .../PassiveAggressiveRegressionUDTF.java | 2 +- .../hivemall/regression/RegressionBaseUDTF.java | 12 +- .../NewSpaceEfficientNewDenseModelTest.java | 60 ++++ .../model/SpaceEfficientDenseModelTest.java | 60 ---- .../java/hivemall/mix/server/MixServerTest.java | 14 +- .../hivemall/mix/server/MixServerSuite.scala | 4 +- 25 files changed, 1512 insertions(+), 131 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/LearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java index 7fd5190..4cf3c7f 100644 --- a/core/src/main/java/hivemall/LearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java @@ -25,6 +25,9 @@ import hivemall.model.DenseModel; import hivemall.model.PredictionModel; import hivemall.model.SpaceEfficientDenseModel; import hivemall.model.SparseModel; +import hivemall.model.NewDenseModel; +import hivemall.model.NewSparseModel; +import hivemall.model.NewSpaceEfficientDenseModel; import hivemall.model.SynchronizedModelWrapper; import hivemall.model.WeightValue; import hivemall.model.WeightValue.WeightValueWithCovar; @@ -199,6 +202,36 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { return model; } + protected PredictionModel createNewModel(String label) { + PredictionModel model; + final boolean useCovar = useCovariance(); + if (dense_model) { + if (disable_halffloat == false && model_dims > 16777216) { + logger.info("Build a space efficient dense model with " + model_dims + + " initial dimensions" + (useCovar ? " w/ covariances" : "")); + model = new NewSpaceEfficientDenseModel(model_dims, useCovar); + } else { + logger.info("Build a dense model with initial with " + model_dims + + " initial dimensions" + (useCovar ? " w/ covariances" : "")); + model = new NewDenseModel(model_dims, useCovar); + } + } else { + int initModelSize = getInitialModelSize(); + logger.info("Build a sparse model with initial with " + initModelSize + + " initial dimensions"); + model = new NewSparseModel(initModelSize, useCovar); + } + if (mixConnectInfo != null) { + model.configureClock(); + model = new SynchronizedModelWrapper(model); + MixClient client = configureMixClient(mixConnectInfo, label, model); + model.configureMix(client, mixCancel); + this.mixClient = client; + } + assert (model != null); + return model; + } + // If a model implements a optimizer, it must override this protected Map<String, String> getOptimzierOptions() { return null; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java index ac8afcb..b42ab05 100644 --- a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java @@ -18,11 +18,11 @@ */ package hivemall.classifier; -import hivemall.optimizer.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; import hivemall.model.PredictionResult; import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.optimizer.LossFunctions; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java index a6714f4..b512a34 100644 --- a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java +++ b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java @@ -18,13 +18,128 @@ */ package hivemall.classifier; +import hivemall.model.FeatureValue; +import hivemall.model.IWeightValue; +import hivemall.model.WeightValue.WeightValueParamsF2; +import hivemall.optimizer.LossFunctions; +import hivemall.utils.lang.Primitives; + +import javax.annotation.Nonnull; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; + +/** + * @deprecated Use {@link hivemall.classifier.GeneralClassifierUDTF} instead + */ @Deprecated -public final class AdaGradRDAUDTF extends GeneralClassifierUDTF { +@Description(name = "train_adagrad_rda", + value = "_FUNC_(list<string|int|bigint> features, int label [, const string options])" + + " - Returns a relation consists of <string|int|bigint feature, float weight>", + extended = "Build a prediction model by Adagrad+RDA regularization binary classifier") +public final class AdaGradRDAUDTF extends BinaryOnlineClassifierUDTF { + + private float eta; + private float lambda; + private float scaling; + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + final int numArgs = argOIs.length; + if (numArgs != 2 && numArgs != 3) { + throw new UDFArgumentException( + "_FUNC_ takes 2 or 3 arguments: List<Text|Int|BitInt> features, int label [, constant string options]"); + } + + StructObjectInspector oi = super.initialize(argOIs); + model.configureParams(true, false, true); + return oi; + } + + @Override + protected Options getOptions() { + Options opts = super.getOptions(); + opts.addOption("eta", "eta0", true, "The learning rate \\eta [default 0.1]"); + opts.addOption("lambda", true, "lambda constant of RDA [default: 1E-6f]"); + opts.addOption("scale", true, + "Internal scaling/descaling factor for cumulative weights [default: 100]"); + return opts; + } - public AdaGradRDAUDTF() { - optimizerOptions.put("optimizer", "AdaGrad"); - optimizerOptions.put("regularization", "RDA"); - optimizerOptions.put("lambda", "1e-6"); + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = super.processOptions(argOIs); + if (cl == null) { + this.eta = 0.1f; + this.lambda = 1E-6f; + this.scaling = 100f; + } else { + this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.1f); + this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), 1E-6f); + this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f); + } + return cl; } + @Override + protected void train(@Nonnull final FeatureValue[] features, final int label) { + final float y = label > 0 ? 1.f : -1.f; + + float p = predict(features); + float loss = LossFunctions.hingeLoss(p, y); // 1.0 - y * p + if (loss <= 0.f) { // max(0, 1 - y * p) + return; + } + // subgradient => -y * W dot xi + update(features, y, count); + } + + protected void update(@Nonnull final FeatureValue[] features, final float y, final int t) { + for (FeatureValue f : features) {// w[f] += y * x[f] + if (f == null) { + continue; + } + Object x = f.getFeature(); + float xi = f.getValueAsFloat(); + + updateWeight(x, xi, y, t); + } + } + + protected void updateWeight(@Nonnull final Object x, final float xi, final float y, + final float t) { + final float gradient = -y * xi; + final float scaled_gradient = gradient * scaling; + + float scaled_sum_sqgrad = 0.f; + float scaled_sum_grad = 0.f; + IWeightValue old = model.get(x); + if (old != null) { + scaled_sum_sqgrad = old.getSumOfSquaredGradients(); + scaled_sum_grad = old.getSumOfGradients(); + } + scaled_sum_grad += scaled_gradient; + scaled_sum_sqgrad += (scaled_gradient * scaled_gradient); + + float sum_grad = scaled_sum_grad * scaling; + double sum_sqgrad = scaled_sum_sqgrad * scaling; + + // sign(u_{t,i}) + float sign = (sum_grad > 0.f) ? 1.f : -1.f; + // |u_{t,i}|/t - \lambda + float meansOfGradients = sign * sum_grad / t - lambda; + if (meansOfGradients < 0.f) { + // x_{t,i} = 0 + model.delete(x); + } else { + // x_{t,i} = -sign(u_{t,i}) * \frac{\eta t}{\sqrt{G_{t,ii}}}(|u_{t,i}|/t - \lambda) + float weight = -1.f * sign * eta * t * meansOfGradients / (float) Math.sqrt(sum_sqgrad); + IWeightValue new_w = new WeightValueParamsF2(weight, scaled_sum_sqgrad, scaled_sum_grad); + model.set(x, new_w); + } + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java index 0ee5d5f..efeeb9d 100644 --- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java @@ -60,6 +60,16 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { protected Optimizer optimizerImpl; protected int count; + private boolean enableNewModel; + + public BinaryOnlineClassifierUDTF() { + this.enableNewModel = false; + } + + public BinaryOnlineClassifierUDTF(boolean enableNewModel) { + this.enableNewModel = enableNewModel; + } + @Override public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length < 2) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java index feebadd..12bd481 100644 --- a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java @@ -39,6 +39,7 @@ public class GeneralClassifierUDTF extends BinaryOnlineClassifierUDTF { protected final Map<String, String> optimizerOptions; public GeneralClassifierUDTF() { + super(true); // This enables new model interfaces this.optimizerOptions = new HashMap<String, String>(); // Set default values optimizerOptions.put("optimizer", "adagrad"); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java index 9e404cd..191a7b5 100644 --- a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java +++ b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java @@ -18,9 +18,9 @@ */ package hivemall.classifier; -import hivemall.optimizer.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.PredictionResult; +import hivemall.optimizer.LossFunctions; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/DenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/DenseModel.java b/core/src/main/java/hivemall/model/DenseModel.java index 6956875..f142cc1 100644 --- a/core/src/main/java/hivemall/model/DenseModel.java +++ b/core/src/main/java/hivemall/model/DenseModel.java @@ -18,18 +18,21 @@ */ package hivemall.model; -import java.util.Arrays; -import javax.annotation.Nonnull; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - +import hivemall.model.WeightValue.WeightValueParamsF1; +import hivemall.model.WeightValue.WeightValueParamsF2; import hivemall.model.WeightValue.WeightValueWithCovar; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.Copyable; import hivemall.utils.math.MathUtils; +import java.util.Arrays; + +import javax.annotation.Nonnull; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + public final class DenseModel extends AbstractPredictionModel { private static final Log logger = LogFactory.getLog(DenseModel.class); @@ -37,6 +40,13 @@ public final class DenseModel extends AbstractPredictionModel { private float[] weights; private float[] covars; + // optional values for adagrad + private float[] sum_of_squared_gradients; + // optional value for adadelta + private float[] sum_of_squared_delta_x; + // optional value for adagrad+rda + private float[] sum_of_gradients; + // optional value for MIX private short[] clocks; private byte[] deltaUpdates; @@ -57,6 +67,9 @@ public final class DenseModel extends AbstractPredictionModel { } else { this.covars = null; } + this.sum_of_squared_gradients = null; + this.sum_of_squared_delta_x = null; + this.sum_of_gradients = null; this.clocks = null; this.deltaUpdates = null; } @@ -72,6 +85,20 @@ public final class DenseModel extends AbstractPredictionModel { } @Override + public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients) { + if (sum_of_squared_gradients) { + this.sum_of_squared_gradients = new float[size]; + } + if (sum_of_squared_delta_x) { + this.sum_of_squared_delta_x = new float[size]; + } + if (sum_of_gradients) { + this.sum_of_gradients = new float[size]; + } + } + + @Override public void configureClock() { if (clocks == null) { this.clocks = new short[size]; @@ -102,7 +129,16 @@ public final class DenseModel extends AbstractPredictionModel { this.covars = Arrays.copyOf(covars, newSize); Arrays.fill(covars, oldSize, newSize, 1.f); } - if(clocks != null) { + if (sum_of_squared_gradients != null) { + this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); + } + if (sum_of_squared_delta_x != null) { + this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize); + } + if (sum_of_gradients != null) { + this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize); + } + if (clocks != null) { this.clocks = Arrays.copyOf(clocks, newSize); this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize); } @@ -116,7 +152,17 @@ public final class DenseModel extends AbstractPredictionModel { if (i >= size) { return null; } - if(covars != null) { + if (sum_of_squared_gradients != null) { + if (sum_of_squared_delta_x != null) { + return (T) new WeightValueParamsF2(weights[i], sum_of_squared_gradients[i], + sum_of_squared_delta_x[i]); + } else if (sum_of_gradients != null) { + return (T) new WeightValueParamsF2(weights[i], sum_of_squared_gradients[i], + sum_of_gradients[i]); + } else { + return (T) new WeightValueParamsF1(weights[i], sum_of_squared_gradients[i]); + } + } else if (covars != null) { return (T) new WeightValueWithCovar(weights[i], covars[i]); } else { return (T) new WeightValue(weights[i]); @@ -135,6 +181,15 @@ public final class DenseModel extends AbstractPredictionModel { covar = value.getCovariance(); covars[i] = covar; } + if (sum_of_squared_gradients != null) { + sum_of_squared_gradients[i] = value.getSumOfSquaredGradients(); + } + if (sum_of_squared_delta_x != null) { + sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX(); + } + if (sum_of_gradients != null) { + sum_of_gradients[i] = value.getSumOfGradients(); + } short clock = 0; int delta = 0; if (clocks != null && value.isTouched()) { @@ -158,6 +213,15 @@ public final class DenseModel extends AbstractPredictionModel { if (covars != null) { covars[i] = 1.f; } + if (sum_of_squared_gradients != null) { + sum_of_squared_gradients[i] = 0.f; + } + if (sum_of_squared_delta_x != null) { + sum_of_squared_delta_x[i] = 0.f; + } + if (sum_of_gradients != null) { + sum_of_gradients[i] = 0.f; + } // avoid clock/delta } @@ -171,10 +235,8 @@ public final class DenseModel extends AbstractPredictionModel { } @Override - public void setWeight(Object feature, float value) { - int i = HiveUtils.parseInt(feature); - ensureCapacity(i); - weights[i] = value; + public void setWeight(@Nonnull Object feature, float value) { + throw new UnsupportedOperationException(); } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/NewDenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/NewDenseModel.java b/core/src/main/java/hivemall/model/NewDenseModel.java new file mode 100644 index 0000000..920794c --- /dev/null +++ b/core/src/main/java/hivemall/model/NewDenseModel.java @@ -0,0 +1,293 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.model; + +import java.util.Arrays; +import javax.annotation.Nonnull; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Copyable; +import hivemall.utils.math.MathUtils; + +public final class NewDenseModel extends AbstractPredictionModel { + private static final Log logger = LogFactory.getLog(NewDenseModel.class); + + private int size; + private float[] weights; + private float[] covars; + + // optional value for MIX + private short[] clocks; + private byte[] deltaUpdates; + + public NewDenseModel(int ndims) { + this(ndims, false); + } + + public NewDenseModel(int ndims, boolean withCovar) { + super(); + int size = ndims + 1; + this.size = size; + this.weights = new float[size]; + if (withCovar) { + float[] covars = new float[size]; + Arrays.fill(covars, 1f); + this.covars = covars; + } else { + this.covars = null; + } + this.clocks = null; + this.deltaUpdates = null; + } + + @Override + protected boolean isDenseModel() { + return true; + } + + @Override + public boolean hasCovariance() { + return covars != null; + } + + @Override + public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients) {} + + @Override + public void configureClock() { + if (clocks == null) { + this.clocks = new short[size]; + this.deltaUpdates = new byte[size]; + } + } + + @Override + public boolean hasClock() { + return clocks != null; + } + + @Override + public void resetDeltaUpdates(int feature) { + deltaUpdates[feature] = 0; + } + + private void ensureCapacity(final int index) { + if (index >= size) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + int oldSize = size; + logger.info("Expands internal array size from " + oldSize + " to " + newSize + " (" + + bits + " bits)"); + this.size = newSize; + this.weights = Arrays.copyOf(weights, newSize); + if (covars != null) { + this.covars = Arrays.copyOf(covars, newSize); + Arrays.fill(covars, oldSize, newSize, 1.f); + } + if(clocks != null) { + this.clocks = Arrays.copyOf(clocks, newSize); + this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize); + } + } + } + + @SuppressWarnings("unchecked") + @Override + public <T extends IWeightValue> T get(Object feature) { + final int i = HiveUtils.parseInt(feature); + if (i >= size) { + return null; + } + if(covars != null) { + return (T) new WeightValueWithCovar(weights[i], covars[i]); + } else { + return (T) new WeightValue(weights[i]); + } + } + + @Override + public <T extends IWeightValue> void set(Object feature, T value) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + float weight = value.get(); + weights[i] = weight; + float covar = 1.f; + boolean hasCovar = value.hasCovariance(); + if (hasCovar) { + covar = value.getCovariance(); + covars[i] = covar; + } + short clock = 0; + int delta = 0; + if (clocks != null && value.isTouched()) { + clock = (short) (clocks[i] + 1); + clocks[i] = clock; + delta = deltaUpdates[i] + 1; + assert (delta > 0) : delta; + deltaUpdates[i] = (byte) delta; + } + + onUpdate(i, weight, covar, clock, delta, hasCovar); + } + + @Override + public void delete(@Nonnull Object feature) { + final int i = HiveUtils.parseInt(feature); + if (i >= size) { + return; + } + weights[i] = 0.f; + if (covars != null) { + covars[i] = 1.f; + } + // avoid clock/delta + } + + @Override + public float getWeight(Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return 0f; + } + return weights[i]; + } + + @Override + public void setWeight(Object feature, float value) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weights[i] = value; + } + + @Override + public float getCovariance(Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return 1f; + } + return covars[i]; + } + + @Override + protected void _set(Object feature, float weight, short clock) { + int i = ((Integer) feature).intValue(); + ensureCapacity(i); + weights[i] = weight; + clocks[i] = clock; + deltaUpdates[i] = 0; + } + + @Override + protected void _set(Object feature, float weight, float covar, short clock) { + int i = ((Integer) feature).intValue(); + ensureCapacity(i); + weights[i] = weight; + covars[i] = covar; + clocks[i] = clock; + deltaUpdates[i] = 0; + } + + @Override + public int size() { + return size; + } + + @Override + public boolean contains(Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return false; + } + float w = weights[i]; + return w != 0.f; + } + + @SuppressWarnings("unchecked") + @Override + public <K, V extends IWeightValue> IMapIterator<K, V> entries() { + return (IMapIterator<K, V>) new Itr(); + } + + private final class Itr implements IMapIterator<Number, IWeightValue> { + + private int cursor; + private final WeightValueWithCovar tmpWeight; + + private Itr() { + this.cursor = -1; + this.tmpWeight = new WeightValueWithCovar(); + } + + @Override + public boolean hasNext() { + return cursor < size; + } + + @Override + public int next() { + ++cursor; + if (!hasNext()) { + return -1; + } + return cursor; + } + + @Override + public Integer getKey() { + return cursor; + } + + @Override + public IWeightValue getValue() { + if (covars == null) { + float w = weights[cursor]; + WeightValue v = new WeightValue(w); + v.setTouched(w != 0f); + return v; + } else { + float w = weights[cursor]; + float cov = covars[cursor]; + WeightValueWithCovar v = new WeightValueWithCovar(w, cov); + v.setTouched(w != 0.f || cov != 1.f); + return v; + } + } + + @Override + public <T extends Copyable<IWeightValue>> void getValue(T probe) { + float w = weights[cursor]; + tmpWeight.value = w; + float cov = 1.f; + if (covars != null) { + cov = covars[cursor]; + tmpWeight.setCovariance(cov); + } + tmpWeight.setTouched(w != 0.f || cov != 1.f); + probe.copyFrom(tmpWeight); + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java new file mode 100644 index 0000000..48eb62a --- /dev/null +++ b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java @@ -0,0 +1,317 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.model; + +import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Copyable; +import hivemall.utils.lang.HalfFloat; +import hivemall.utils.math.MathUtils; + +import java.util.Arrays; +import javax.annotation.Nonnull; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel { + private static final Log logger = LogFactory.getLog(NewSpaceEfficientDenseModel.class); + + private int size; + private short[] weights; + private short[] covars; + + // optional value for MIX + private short[] clocks; + private byte[] deltaUpdates; + + public NewSpaceEfficientDenseModel(int ndims) { + this(ndims, false); + } + + public NewSpaceEfficientDenseModel(int ndims, boolean withCovar) { + super(); + int size = ndims + 1; + this.size = size; + this.weights = new short[size]; + if (withCovar) { + short[] covars = new short[size]; + Arrays.fill(covars, HalfFloat.ONE); + this.covars = covars; + } else { + this.covars = null; + } + this.clocks = null; + this.deltaUpdates = null; + } + + @Override + protected boolean isDenseModel() { + return true; + } + + @Override + public boolean hasCovariance() { + return covars != null; + } + + @Override + public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients) {} + + @Override + public void configureClock() { + if (clocks == null) { + this.clocks = new short[size]; + this.deltaUpdates = new byte[size]; + } + } + + @Override + public boolean hasClock() { + return clocks != null; + } + + @Override + public void resetDeltaUpdates(int feature) { + deltaUpdates[feature] = 0; + } + + private float getWeight(final int i) { + final short w = weights[i]; + return (w == HalfFloat.ZERO) ? HalfFloat.ZERO : HalfFloat.halfFloatToFloat(w); + } + + private float getCovar(final int i) { + return HalfFloat.halfFloatToFloat(covars[i]); + } + + private void _setWeight(final int i, final float v) { + if(Math.abs(v) >= HalfFloat.MAX_FLOAT) { + throw new IllegalArgumentException("Acceptable maximum weight is " + + HalfFloat.MAX_FLOAT + ": " + v); + } + weights[i] = HalfFloat.floatToHalfFloat(v); + } + + private void setCovar(final int i, final float v) { + HalfFloat.checkRange(v); + covars[i] = HalfFloat.floatToHalfFloat(v); + } + + private void ensureCapacity(final int index) { + if (index >= size) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + int oldSize = size; + logger.info("Expands internal array size from " + oldSize + " to " + newSize + " (" + + bits + " bits)"); + this.size = newSize; + this.weights = Arrays.copyOf(weights, newSize); + if (covars != null) { + this.covars = Arrays.copyOf(covars, newSize); + Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE); + } + if(clocks != null) { + this.clocks = Arrays.copyOf(clocks, newSize); + this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize); + } + } + } + + @SuppressWarnings("unchecked") + @Override + public <T extends IWeightValue> T get(Object feature) { + final int i = HiveUtils.parseInt(feature); + if (i >= size) { + return null; + } + + if(covars != null) { + return (T) new WeightValueWithCovar(getWeight(i), getCovar(i)); + } else { + return (T) new WeightValue(getWeight(i)); + } + } + + @Override + public <T extends IWeightValue> void set(Object feature, T value) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + float weight = value.get(); + _setWeight(i, weight); + float covar = 1.f; + boolean hasCovar = value.hasCovariance(); + if (hasCovar) { + covar = value.getCovariance(); + setCovar(i, covar); + } + short clock = 0; + int delta = 0; + if (clocks != null && value.isTouched()) { + clock = (short) (clocks[i] + 1); + clocks[i] = clock; + delta = deltaUpdates[i] + 1; + assert (delta > 0) : delta; + deltaUpdates[i] = (byte) delta; + } + + onUpdate(i, weight, covar, clock, delta, hasCovar); + } + + @Override + public void delete(@Nonnull Object feature) { + final int i = HiveUtils.parseInt(feature); + if (i >= size) { + return; + } + _setWeight(i, 0.f); + if(covars != null) { + setCovar(i, 1.f); + } + // avoid clock/delta + } + + @Override + public float getWeight(Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return 0f; + } + return getWeight(i); + } + + @Override + public void setWeight(Object feature, float value) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + _setWeight(i, value); + } + + @Override + public float getCovariance(Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return 1f; + } + return getCovar(i); + } + + @Override + protected void _set(Object feature, float weight, short clock) { + int i = ((Integer) feature).intValue(); + ensureCapacity(i); + _setWeight(i, weight); + clocks[i] = clock; + deltaUpdates[i] = 0; + } + + @Override + protected void _set(Object feature, float weight, float covar, short clock) { + int i = ((Integer) feature).intValue(); + ensureCapacity(i); + _setWeight(i, weight); + setCovar(i, covar); + clocks[i] = clock; + deltaUpdates[i] = 0; + } + + @Override + public int size() { + return size; + } + + @Override + public boolean contains(Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return false; + } + float w = getWeight(i); + return w != 0.f; + } + + @SuppressWarnings("unchecked") + @Override + public <K, V extends IWeightValue> IMapIterator<K, V> entries() { + return (IMapIterator<K, V>) new Itr(); + } + + private final class Itr implements IMapIterator<Number, IWeightValue> { + + private int cursor; + private final WeightValueWithCovar tmpWeight; + + private Itr() { + this.cursor = -1; + this.tmpWeight = new WeightValueWithCovar(); + } + + @Override + public boolean hasNext() { + return cursor < size; + } + + @Override + public int next() { + ++cursor; + if (!hasNext()) { + return -1; + } + return cursor; + } + + @Override + public Integer getKey() { + return cursor; + } + + @Override + public IWeightValue getValue() { + if (covars == null) { + float w = getWeight(cursor); + WeightValue v = new WeightValue(w); + v.setTouched(w != 0f); + return v; + } else { + float w = getWeight(cursor); + float cov = getCovar(cursor); + WeightValueWithCovar v = new WeightValueWithCovar(w, cov); + v.setTouched(w != 0.f || cov != 1.f); + return v; + } + } + + @Override + public <T extends Copyable<IWeightValue>> void getValue(T probe) { + float w = getWeight(cursor); + tmpWeight.value = w; + float cov = 1.f; + if (covars != null) { + cov = getCovar(cursor); + tmpWeight.setCovariance(cov); + } + tmpWeight.setTouched(w != 0.f || cov != 1.f); + probe.copyFrom(tmpWeight); + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/NewSparseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/NewSparseModel.java b/core/src/main/java/hivemall/model/NewSparseModel.java new file mode 100644 index 0000000..4c21830 --- /dev/null +++ b/core/src/main/java/hivemall/model/NewSparseModel.java @@ -0,0 +1,197 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.model; + +import hivemall.model.WeightValueWithClock.WeightValueParamsF1Clock; +import hivemall.model.WeightValueWithClock.WeightValueParamsF2Clock; +import hivemall.model.WeightValueWithClock.WeightValueWithCovarClock; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.OpenHashMap; + +import javax.annotation.Nonnull; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +public final class NewSparseModel extends AbstractPredictionModel { + private static final Log logger = LogFactory.getLog(NewSparseModel.class); + + private final OpenHashMap<Object, IWeightValue> weights; + private final boolean hasCovar; + private boolean clockEnabled; + + public NewSparseModel(int size) { + this(size, false); + } + + public NewSparseModel(int size, boolean hasCovar) { + super(); + this.weights = new OpenHashMap<Object, IWeightValue>(size); + this.hasCovar = hasCovar; + this.clockEnabled = false; + } + + @Override + protected boolean isDenseModel() { + return false; + } + + @Override + public boolean hasCovariance() { + return hasCovar; + } + + @Override + public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients) {} + + @Override + public void configureClock() { + this.clockEnabled = true; + } + + @Override + public boolean hasClock() { + return clockEnabled; + } + + @SuppressWarnings("unchecked") + @Override + public <T extends IWeightValue> T get(final Object feature) { + return (T) weights.get(feature); + } + + @Override + public <T extends IWeightValue> void set(final Object feature, final T value) { + assert (feature != null); + assert (value != null); + + final IWeightValue wrapperValue = wrapIfRequired(value); + + if (clockEnabled && value.isTouched()) { + IWeightValue old = weights.get(feature); + if (old != null) { + short newclock = (short) (old.getClock() + (short) 1); + wrapperValue.setClock(newclock); + int newDelta = old.getDeltaUpdates() + 1; + wrapperValue.setDeltaUpdates((byte) newDelta); + } + } + weights.put(feature, wrapperValue); + + onUpdate(feature, wrapperValue); + } + + @Override + public void delete(@Nonnull Object feature) { + weights.remove(feature); + } + + private IWeightValue wrapIfRequired(final IWeightValue value) { + final IWeightValue wrapper; + if (clockEnabled) { + switch (value.getType()) { + case NoParams: + wrapper = new WeightValueWithClock(value); + break; + case ParamsCovar: + wrapper = new WeightValueWithCovarClock(value); + break; + case ParamsF1: + wrapper = new WeightValueParamsF1Clock(value); + break; + case ParamsF2: + wrapper = new WeightValueParamsF2Clock(value); + break; + default: + throw new IllegalStateException("Unexpected value type: " + value.getType()); + } + } else { + wrapper = value; + } + return wrapper; + } + + @Override + public float getWeight(final Object feature) { + IWeightValue v = weights.get(feature); + return v == null ? 0.f : v.get(); + } + + @Override + public void setWeight(Object feature, float value) { + if(weights.containsKey(feature)) { + IWeightValue weight = weights.get(feature); + weight.set(value); + } else { + IWeightValue weight = new WeightValue(value); + weight.set(value); + weights.put(feature, weight); + } + } + + @Override + public float getCovariance(final Object feature) { + IWeightValue v = weights.get(feature); + return v == null ? 1.f : v.getCovariance(); + } + + @Override + protected void _set(final Object feature, final float weight, final short clock) { + final IWeightValue w = weights.get(feature); + if (w == null) { + logger.warn("Previous weight not found: " + feature); + throw new IllegalStateException("Previous weight not found " + feature); + } + w.set(weight); + w.setClock(clock); + w.setDeltaUpdates(BYTE0); + } + + @Override + protected void _set(final Object feature, final float weight, final float covar, + final short clock) { + final IWeightValue w = weights.get(feature); + if (w == null) { + logger.warn("Previous weight not found: " + feature); + throw new IllegalStateException("Previous weight not found: " + feature); + } + w.set(weight); + w.setCovariance(covar); + w.setClock(clock); + w.setDeltaUpdates(BYTE0); + } + + @Override + public int size() { + return weights.size(); + } + + @Override + public boolean contains(final Object feature) { + return weights.containsKey(feature); + } + + @SuppressWarnings("unchecked") + @Override + public <K, V extends IWeightValue> IMapIterator<K, V> entries() { + return (IMapIterator<K, V>) weights.entries(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/PredictionModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/PredictionModel.java b/core/src/main/java/hivemall/model/PredictionModel.java index 8d8dd2b..ea82f62 100644 --- a/core/src/main/java/hivemall/model/PredictionModel.java +++ b/core/src/main/java/hivemall/model/PredictionModel.java @@ -34,6 +34,9 @@ public interface PredictionModel extends MixedModel { boolean hasCovariance(); + void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients); + void configureClock(); boolean hasClock(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java index 8b668e7..caa9fea 100644 --- a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java +++ b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java @@ -18,6 +18,8 @@ */ package hivemall.model; +import hivemall.model.WeightValue.WeightValueParamsF1; +import hivemall.model.WeightValue.WeightValueParamsF2; import hivemall.model.WeightValue.WeightValueWithCovar; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; @@ -26,6 +28,7 @@ import hivemall.utils.lang.HalfFloat; import hivemall.utils.math.MathUtils; import java.util.Arrays; + import javax.annotation.Nonnull; import org.apache.commons.logging.Log; @@ -38,6 +41,13 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { private short[] weights; private short[] covars; + // optional value for adagrad + private float[] sum_of_squared_gradients; + // optional value for adadelta + private float[] sum_of_squared_delta_x; + // optional value for adagrad+rda + private float[] sum_of_gradients; + // optional value for MIX private short[] clocks; private byte[] deltaUpdates; @@ -58,6 +68,9 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } else { this.covars = null; } + this.sum_of_squared_gradients = null; + this.sum_of_squared_delta_x = null; + this.sum_of_gradients = null; this.clocks = null; this.deltaUpdates = null; } @@ -73,6 +86,20 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override + public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients) { + if (sum_of_squared_gradients) { + this.sum_of_squared_gradients = new float[size]; + } + if (sum_of_squared_delta_x) { + this.sum_of_squared_delta_x = new float[size]; + } + if (sum_of_gradients) { + this.sum_of_gradients = new float[size]; + } + } + + @Override public void configureClock() { if (clocks == null) { this.clocks = new short[size]; @@ -99,11 +126,8 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { return HalfFloat.halfFloatToFloat(covars[i]); } - private void _setWeight(final int i, final float v) { - if(Math.abs(v) >= HalfFloat.MAX_FLOAT) { - throw new IllegalArgumentException("Acceptable maximum weight is " - + HalfFloat.MAX_FLOAT + ": " + v); - } + private void setWeight(final int i, final float v) { + HalfFloat.checkRange(v); weights[i] = HalfFloat.floatToHalfFloat(v); } @@ -125,7 +149,16 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { this.covars = Arrays.copyOf(covars, newSize); Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE); } - if(clocks != null) { + if (sum_of_squared_gradients != null) { + this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); + } + if (sum_of_squared_delta_x != null) { + this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize); + } + if (sum_of_gradients != null) { + this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize); + } + if (clocks != null) { this.clocks = Arrays.copyOf(clocks, newSize); this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize); } @@ -139,8 +172,17 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { if (i >= size) { return null; } - - if(covars != null) { + if (sum_of_squared_gradients != null) { + if (sum_of_squared_delta_x != null) { + return (T) new WeightValueParamsF2(getWeight(i), sum_of_squared_gradients[i], + sum_of_squared_delta_x[i]); + } else if (sum_of_gradients != null) { + return (T) new WeightValueParamsF2(getWeight(i), sum_of_squared_gradients[i], + sum_of_gradients[i]); + } else { + return (T) new WeightValueParamsF1(getWeight(i), sum_of_squared_gradients[i]); + } + } else if (covars != null) { return (T) new WeightValueWithCovar(getWeight(i), getCovar(i)); } else { return (T) new WeightValue(getWeight(i)); @@ -152,13 +194,22 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { int i = HiveUtils.parseInt(feature); ensureCapacity(i); float weight = value.get(); - _setWeight(i, weight); + setWeight(i, weight); float covar = 1.f; boolean hasCovar = value.hasCovariance(); if (hasCovar) { covar = value.getCovariance(); setCovar(i, covar); } + if (sum_of_squared_gradients != null) { + sum_of_squared_gradients[i] = value.getSumOfSquaredGradients(); + } + if (sum_of_squared_delta_x != null) { + sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX(); + } + if (sum_of_gradients != null) { + sum_of_gradients[i] = value.getSumOfGradients(); + } short clock = 0; int delta = 0; if (clocks != null && value.isTouched()) { @@ -178,10 +229,19 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { if (i >= size) { return; } - _setWeight(i, 0.f); - if(covars != null) { + setWeight(i, 0.f); + if (covars != null) { setCovar(i, 1.f); } + if (sum_of_squared_gradients != null) { + sum_of_squared_gradients[i] = 0.f; + } + if (sum_of_squared_delta_x != null) { + sum_of_squared_delta_x[i] = 0.f; + } + if (sum_of_gradients != null) { + sum_of_gradients[i] = 0.f; + } // avoid clock/delta } @@ -195,10 +255,8 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override - public void setWeight(Object feature, float value) { - int i = HiveUtils.parseInt(feature); - ensureCapacity(i); - _setWeight(i, value); + public void setWeight(@Nonnull Object feature, float value) { + throw new UnsupportedOperationException(); } @Override @@ -214,7 +272,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { protected void _set(Object feature, float weight, short clock) { int i = ((Integer) feature).intValue(); ensureCapacity(i); - _setWeight(i, weight); + setWeight(i, weight); clocks[i] = clock; deltaUpdates[i] = 0; } @@ -223,7 +281,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { protected void _set(Object feature, float weight, float covar, short clock) { int i = ((Integer) feature).intValue(); ensureCapacity(i); - _setWeight(i, weight); + setWeight(i, weight); setCovar(i, covar); clocks[i] = clock; deltaUpdates[i] = 0; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/SparseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java index bab982f..f4c4c55 100644 --- a/core/src/main/java/hivemall/model/SparseModel.java +++ b/core/src/main/java/hivemall/model/SparseModel.java @@ -36,10 +36,6 @@ public final class SparseModel extends AbstractPredictionModel { private final boolean hasCovar; private boolean clockEnabled; - public SparseModel(int size) { - this(size, false); - } - public SparseModel(int size, boolean hasCovar) { super(); this.weights = new OpenHashMap<Object, IWeightValue>(size); @@ -58,6 +54,10 @@ public final class SparseModel extends AbstractPredictionModel { } @Override + public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients) {} + + @Override public void configureClock() { this.clockEnabled = true; } @@ -131,15 +131,8 @@ public final class SparseModel extends AbstractPredictionModel { } @Override - public void setWeight(Object feature, float value) { - if(weights.containsKey(feature)) { - IWeightValue weight = weights.get(feature); - weight.set(value); - } else { - IWeightValue weight = new WeightValue(value); - weight.set(value); - weights.put(feature, weight); - } + public void setWeight(@Nonnull Object feature, float value) { + throw new UnsupportedOperationException(); } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java index 87e89b6..dcb0bc9 100644 --- a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java +++ b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java @@ -63,6 +63,12 @@ public final class SynchronizedModelWrapper implements PredictionModel { } @Override + public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients) { + model.configureParams(sum_of_squared_gradients, sum_of_squared_delta_x, sum_of_gradients); + } + + @Override public void configureClock() { model.configureClock(); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java index 0c964c8..0503145 100644 --- a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java +++ b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java @@ -18,12 +18,12 @@ */ package hivemall.regression; -import hivemall.optimizer.LossFunctions; import hivemall.common.OnlineVariance; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; import hivemall.model.PredictionResult; import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.optimizer.LossFunctions; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java index 50dc9b5..93453c1 100644 --- a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java +++ b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java @@ -18,14 +18,126 @@ */ package hivemall.regression; +import hivemall.model.FeatureValue; +import hivemall.model.IWeightValue; +import hivemall.model.WeightValue.WeightValueParamsF2; +import hivemall.optimizer.LossFunctions; +import hivemall.utils.lang.Primitives; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; + /** * ADADELTA: AN ADAPTIVE LEARNING RATE METHOD. + * + * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead */ @Deprecated -public final class AdaDeltaUDTF extends GeneralRegressionUDTF { +@Description( + name = "train_adadelta_regr", + value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])" + + " - Returns a relation consists of <{int|bigint|string} feature, float weight>") +public final class AdaDeltaUDTF extends RegressionBaseUDTF { + + private float decay; + private float eps; + private float scaling; + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + final int numArgs = argOIs.length; + if (numArgs != 2 && numArgs != 3) { + throw new UDFArgumentException( + "AdaDeltaUDTF takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]"); + } + + StructObjectInspector oi = super.initialize(argOIs); + model.configureParams(true, true, false); + return oi; + } + + @Override + protected Options getOptions() { + Options opts = super.getOptions(); + opts.addOption("rho", "decay", true, "Decay rate [default 0.95]"); + opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1e-6]"); + opts.addOption("scale", true, + "Internal scaling/descaling factor for cumulative weights [100]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = super.processOptions(argOIs); + if (cl == null) { + this.decay = 0.95f; + this.eps = 1e-6f; + this.scaling = 100f; + } else { + this.decay = Primitives.parseFloat(cl.getOptionValue("decay"), 0.95f); + this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1E-6f); + this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f); + } + return cl; + } + + @Override + protected final void checkTargetValue(final float target) throws UDFArgumentException { + if (target < 0.f || target > 1.f) { + throw new UDFArgumentException("target must be in range 0 to 1: " + target); + } + } + + @Override + protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) { + float gradient = LossFunctions.logisticLoss(target, predicted); + onlineUpdate(features, gradient); + } + + @Override + protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) { + final float g_g = gradient * (gradient / scaling); + + for (FeatureValue f : features) {// w[i] += y * x[i] + if (f == null) { + continue; + } + Object x = f.getFeature(); + float xi = f.getValueAsFloat(); + + IWeightValue old_w = model.get(x); + IWeightValue new_w = getNewWeight(old_w, xi, gradient, g_g); + model.set(x, new_w); + } + } + + @Nonnull + protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi, + final float gradient, final float g_g) { + float old_w = 0.f; + float old_scaled_sum_sqgrad = 0.f; + float old_sum_squared_delta_x = 0.f; + if (old != null) { + old_w = old.get(); + old_scaled_sum_sqgrad = old.getSumOfSquaredGradients(); + old_sum_squared_delta_x = old.getSumOfSquaredDeltaX(); + } - public AdaDeltaUDTF() { - optimizerOptions.put("optimizer", "AdaDelta"); + float new_scaled_sum_sq_grad = (decay * old_scaled_sum_sqgrad) + ((1.f - decay) * g_g); + float dx = (float) Math.sqrt((old_sum_squared_delta_x + eps) + / (old_scaled_sum_sqgrad * scaling + eps)) + * gradient; + float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x) + + ((1.f - decay) * dx * dx); + float new_w = old_w + (dx * xi); + return new WeightValueParamsF2(new_w, new_scaled_sum_sq_grad, new_sum_squared_delta_x); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/AdaGradUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/AdaGradUDTF.java b/core/src/main/java/hivemall/regression/AdaGradUDTF.java index 4b5f019..87188fc 100644 --- a/core/src/main/java/hivemall/regression/AdaGradUDTF.java +++ b/core/src/main/java/hivemall/regression/AdaGradUDTF.java @@ -18,14 +18,127 @@ */ package hivemall.regression; +import hivemall.model.FeatureValue; +import hivemall.model.IWeightValue; +import hivemall.model.WeightValue.WeightValueParamsF1; +import hivemall.optimizer.LossFunctions; +import hivemall.utils.lang.Primitives; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; + /** * ADAGRAD algorithm with element-wise adaptive learning rates. + * + * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead */ @Deprecated -public final class AdaGradUDTF extends GeneralRegressionUDTF { +@Description( + name = "train_adagrad_regr", + value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])" + + " - Returns a relation consists of <{int|bigint|string} feature, float weight>") +public final class AdaGradUDTF extends RegressionBaseUDTF { + + private float eta; + private float eps; + private float scaling; + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + final int numArgs = argOIs.length; + if (numArgs != 2 && numArgs != 3) { + throw new UDFArgumentException( + "_FUNC_ takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]"); + } + + StructObjectInspector oi = super.initialize(argOIs); + model.configureParams(true, false, false); + return oi; + } + + @Override + protected Options getOptions() { + Options opts = super.getOptions(); + opts.addOption("eta", "eta0", true, "The initial learning rate [default 1.0]"); + opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]"); + opts.addOption("scale", true, + "Internal scaling/descaling factor for cumulative weights [100]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = super.processOptions(argOIs); + if (cl == null) { + this.eta = 1.f; + this.eps = 1.f; + this.scaling = 100f; + } else { + this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f); + this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1.f); + this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f); + } + return cl; + } + + @Override + protected final void checkTargetValue(final float target) throws UDFArgumentException { + if (target < 0.f || target > 1.f) { + throw new UDFArgumentException("target must be in range 0 to 1: " + target); + } + } + + @Override + protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) { + float gradient = LossFunctions.logisticLoss(target, predicted); + onlineUpdate(features, gradient); + } + + @Override + protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) { + final float g_g = gradient * (gradient / scaling); + + for (FeatureValue f : features) {// w[i] += y * x[i] + if (f == null) { + continue; + } + Object x = f.getFeature(); + float xi = f.getValueAsFloat(); + + IWeightValue old_w = model.get(x); + IWeightValue new_w = getNewWeight(old_w, xi, gradient, g_g); + model.set(x, new_w); + } + } + + @Nonnull + protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi, + final float gradient, final float g_g) { + float old_w = 0.f; + float scaled_sum_sqgrad = 0.f; + + if (old != null) { + old_w = old.get(); + scaled_sum_sqgrad = old.getSumOfSquaredGradients(); + } + scaled_sum_sqgrad += g_g; + + float coeff = eta(scaled_sum_sqgrad) * gradient; + float new_w = old_w + (coeff * xi); + return new WeightValueParamsF1(new_w, scaled_sum_sqgrad); + } - public AdaGradUDTF() { - optimizerOptions.put("optimizer", "AdaGrad"); + protected float eta(final double scaledSumOfSquaredGradients) { + double sumOfSquaredGradients = scaledSumOfSquaredGradients * scaling; + //return eta / (float) Math.sqrt(sumOfSquaredGradients); + return eta / (float) Math.sqrt(eps + sumOfSquaredGradients); // always less than eta0 } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java index 2a8b543..21a784e 100644 --- a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java +++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java @@ -40,6 +40,7 @@ public class GeneralRegressionUDTF extends RegressionBaseUDTF { protected final Map<String, String> optimizerOptions; public GeneralRegressionUDTF() { + super(true); // This enables new model interfaces this.optimizerOptions = new HashMap<String, String>(); // Set default values optimizerOptions.put("optimizer", "adadelta"); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/LogressUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/LogressUDTF.java b/core/src/main/java/hivemall/regression/LogressUDTF.java index ea05da3..78e617d 100644 --- a/core/src/main/java/hivemall/regression/LogressUDTF.java +++ b/core/src/main/java/hivemall/regression/LogressUDTF.java @@ -18,12 +18,69 @@ */ package hivemall.regression; +import hivemall.optimizer.EtaEstimator; +import hivemall.optimizer.LossFunctions; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; + +/** + * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead + */ @Deprecated -public final class LogressUDTF extends GeneralRegressionUDTF { +@Description( + name = "logress", + value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])" + + " - Returns a relation consists of <{int|bigint|string} feature, float weight>") +public final class LogressUDTF extends RegressionBaseUDTF { + + private EtaEstimator etaEstimator; + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + final int numArgs = argOIs.length; + if (numArgs != 2 && numArgs != 3) { + throw new UDFArgumentException( + "LogressUDTF takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]"); + } + + return super.initialize(argOIs); + } + + @Override + protected Options getOptions() { + Options opts = super.getOptions(); + opts.addOption("t", "total_steps", true, "a total of n_samples * epochs time steps"); + opts.addOption("power_t", true, + "The exponent for inverse scaling learning rate [default 0.1]"); + opts.addOption("eta0", true, "The initial learning rate [default 0.1]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = super.processOptions(argOIs); + + this.etaEstimator = EtaEstimator.get(cl); + return cl; + } + + @Override + protected void checkTargetValue(final float target) throws UDFArgumentException { + if (target < 0.f || target > 1.f) { + throw new UDFArgumentException("target must be in range 0 to 1: " + target); + } + } - public LogressUDTF() { - optimizerOptions.put("optimizer", "SGD"); - optimizerOptions.put("eta", "fixed"); + @Override + protected float computeGradient(final float target, final float predicted) { + float eta = etaEstimator.eta(count); + float gradient = LossFunctions.logisticLoss(target, predicted); + return eta * gradient; } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java index e1afe2f..3de56fd 100644 --- a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java +++ b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java @@ -18,10 +18,10 @@ */ package hivemall.regression; -import hivemall.optimizer.LossFunctions; import hivemall.common.OnlineVariance; import hivemall.model.FeatureValue; import hivemall.model.PredictionResult; +import hivemall.optimizer.LossFunctions; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java index 7dc8538..24b0556 100644 --- a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java +++ b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java @@ -72,6 +72,16 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { protected transient Map<Object, FloatAccumulator> accumulated; protected int sampled; + private boolean enableNewModel; + + public RegressionBaseUDTF() { + this.enableNewModel = false; + } + + public RegressionBaseUDTF(boolean enableNewModel) { + this.enableNewModel = enableNewModel; + } + @Override public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length < 2) { @@ -85,7 +95,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector : featureInputOI; - this.model = createModel(); + this.model = enableNewModel? createNewModel(null) : createModel(); if (preloadedModelFile != null) { loadPredictionModel(model, preloadedModelFile, featureOutputOI); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java b/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java new file mode 100644 index 0000000..dd9c4ec --- /dev/null +++ b/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java @@ -0,0 +1,60 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.model; + +import static org.junit.Assert.assertEquals; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.lang.HalfFloat; + +import java.util.Random; + +import org.junit.Test; + +public class NewSpaceEfficientNewDenseModelTest { + + @Test + public void testGetSet() { + final int size = 1 << 12; + + final NewSpaceEfficientDenseModel model1 = new NewSpaceEfficientDenseModel(size); + //model1.configureClock(); + final NewDenseModel model2 = new NewDenseModel(size); + //model2.configureClock(); + + final Random rand = new Random(); + for (int t = 0; t < 1000; t++) { + int i = rand.nextInt(size); + float f = HalfFloat.MAX_FLOAT * rand.nextFloat(); + IWeightValue w = new WeightValue(f); + model1.set(i, w); + model2.set(i, w); + } + + assertEquals(model2.size(), model1.size()); + + IMapIterator<Integer, IWeightValue> itor = model1.entries(); + while (itor.next() != -1) { + int k = itor.getKey(); + float expected = itor.getValue().get(); + float actual = model2.getWeight(k); + assertEquals(expected, actual, 32f); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java b/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java deleted file mode 100644 index e3a1ed4..0000000 --- a/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Hivemall: Hive scalable Machine Learning Library - * - * Copyright (C) 2015 Makoto YUI - * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package hivemall.model; - -import static org.junit.Assert.assertEquals; -import hivemall.utils.collections.IMapIterator; -import hivemall.utils.lang.HalfFloat; - -import java.util.Random; - -import org.junit.Test; - -public class SpaceEfficientDenseModelTest { - - @Test - public void testGetSet() { - final int size = 1 << 12; - - final SpaceEfficientDenseModel model1 = new SpaceEfficientDenseModel(size); - //model1.configureClock(); - final DenseModel model2 = new DenseModel(size); - //model2.configureClock(); - - final Random rand = new Random(); - for (int t = 0; t < 1000; t++) { - int i = rand.nextInt(size); - float f = HalfFloat.MAX_FLOAT * rand.nextFloat(); - IWeightValue w = new WeightValue(f); - model1.set(i, w); - model2.set(i, w); - } - - assertEquals(model2.size(), model1.size()); - - IMapIterator<Integer, IWeightValue> itor = model1.entries(); - while (itor.next() != -1) { - int k = itor.getKey(); - float expected = itor.getValue().get(); - float actual = model2.getWeight(k); - assertEquals(expected, actual, 32f); - } - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java ---------------------------------------------------------------------- diff --git a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java index 38792d8..ec6d556 100644 --- a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java +++ b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java @@ -18,9 +18,9 @@ */ package hivemall.mix.server; -import hivemall.model.DenseModel; +import hivemall.model.NewDenseModel; import hivemall.model.PredictionModel; -import hivemall.model.SparseModel; +import hivemall.model.NewSparseModel; import hivemall.model.WeightValue; import hivemall.mix.MixMessage.MixEventName; import hivemall.mix.client.MixClient; @@ -55,7 +55,7 @@ public class MixServerTest extends HivemallTestBase { waitForState(server, ServerState.RUNNING); - PredictionModel model = new DenseModel(16777216); + PredictionModel model = new NewDenseModel(16777216); model.configureClock(); MixClient client = null; try { @@ -93,7 +93,7 @@ public class MixServerTest extends HivemallTestBase { waitForState(server, ServerState.RUNNING); - PredictionModel model = new DenseModel(16777216); + PredictionModel model = new NewDenseModel(16777216); model.configureClock(); MixClient client = null; try { @@ -151,7 +151,7 @@ public class MixServerTest extends HivemallTestBase { } private static void invokeClient(String groupId, int serverPort) throws InterruptedException { - PredictionModel model = new DenseModel(16777216); + PredictionModel model = new NewDenseModel(16777216); model.configureClock(); MixClient client = null; try { @@ -298,8 +298,8 @@ public class MixServerTest extends HivemallTestBase { private static void invokeClient01(String groupId, int serverPort, boolean denseModel, boolean cancelMix) throws InterruptedException { - PredictionModel model = denseModel ? new DenseModel(100) - : new SparseModel(100, false); + PredictionModel model = denseModel ? new NewDenseModel(100) + : new NewSparseModel(100, false); model.configureClock(); MixClient client = null; try { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala b/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala index 4fb74f1..c0ee72f 100644 --- a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala +++ b/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala @@ -23,7 +23,7 @@ import java.util.logging.Logger import org.scalatest.{BeforeAndAfter, FunSuite} -import hivemall.model.{DenseModel, PredictionModel, WeightValue} +import hivemall.model.{NewDenseModel, PredictionModel, WeightValue} import hivemall.mix.MixMessage.MixEventName import hivemall.mix.client.MixClient import hivemall.mix.server.MixServer.ServerState @@ -95,7 +95,7 @@ class MixServerSuite extends FunSuite with BeforeAndAfter { ignore(testName) { val clients = Executors.newCachedThreadPool() val numClients = nclient - val models = (0 until numClients).map(i => new DenseModel(ndims, false)) + val models = (0 until numClients).map(i => new NewDenseModel(ndims, false)) (0 until numClients).map { i => clients.submit(new Runnable() { override def run(): Unit = {
