Repository: incubator-hivemall Updated Branches: refs/heads/master 1db535876 -> 5e27993b6
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/optimizer/Regularization.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/Regularization.java b/core/src/main/java/hivemall/optimizer/Regularization.java new file mode 100644 index 0000000..4939f60 --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/Regularization.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.optimizer; + +import javax.annotation.Nonnull; +import java.util.Map; + +public abstract class Regularization { + /** the default regularization term 0.0001 */ + public static final float DEFAULT_LAMBDA = 0.0001f; + + protected final float lambda; + + public Regularization(@Nonnull Map<String, String> options) { + float lambda = DEFAULT_LAMBDA; + if (options.containsKey("lambda")) { + lambda = Float.parseFloat(options.get("lambda")); + } + this.lambda = lambda; + } + + public float regularize(float weight, float gradient) { + return gradient + lambda * getRegularizer(weight); + } + + abstract float getRegularizer(float weight); + + public static final class PassThrough extends Regularization { + + public PassThrough(final Map<String, String> options) { + super(options); + } + + @Override + public float getRegularizer(float weight) { + return 0.f; + } + + } + + public static final class L1 extends Regularization { + + public L1(Map<String, String> options) { + super(options); + } + + @Override + public float getRegularizer(float weight) { + return (weight > 0.f ? 
1.f : -1.f); + } + + } + + public static final class L2 extends Regularization { + + public L2(final Map<String, String> options) { + super(options); + } + + @Override + public float getRegularizer(float weight) { + return weight; + } + + } + + public static final class ElasticNet extends Regularization { + public static final float DEFAULT_L1_RATIO = 0.5f; + + protected final L1 l1; + protected final L2 l2; + + protected final float l1Ratio; + + public ElasticNet(Map<String, String> options) { + super(options); + + this.l1 = new L1(options); + this.l2 = new L2(options); + + float l1Ratio = DEFAULT_L1_RATIO; + if (options.containsKey("l1_ratio")) { + l1Ratio = Float.parseFloat(options.get("l1_ratio")); + if (l1Ratio < 0.f || l1Ratio > 1.f) { + throw new IllegalArgumentException("L1 ratio should be in [0.0, 1.0], but got " + + l1Ratio); + } + } + this.l1Ratio = l1Ratio; + } + + @Override + public float getRegularizer(float weight) { + return l1Ratio * l1.getRegularizer(weight) + (1.f - l1Ratio) + * l2.getRegularizer(weight); + } + } + + @Nonnull + public static Regularization get(@Nonnull final Map<String, String> options) + throws IllegalArgumentException { + final String regName = options.get("regularization"); + if (regName == null) { + return new PassThrough(options); + } + + if (regName.toLowerCase().equals("no")) { + return new PassThrough(options); + } else if (regName.toLowerCase().equals("l1")) { + return new L1(options); + } else if (regName.toLowerCase().equals("l2")) { + return new L2(options); + } else if (regName.toLowerCase().equals("elasticnet")) { + return new ElasticNet(options); + } else if (regName.toLowerCase().equals("rda")) { + // Return `PassThrough` because we need special handling for RDA. + // See an implementation of `Optimizer#RDA`. 
+ return new PassThrough(options); + } else { + throw new IllegalArgumentException("Unsupported regularization name: " + regName); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java new file mode 100644 index 0000000..4a003f3 --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.optimizer; + +import hivemall.model.IWeightValue; +import hivemall.model.WeightValue; +import hivemall.optimizer.Optimizer.OptimizerBase; +import hivemall.utils.collections.maps.OpenHashMap; + +import java.util.Map; + +import javax.annotation.Nonnull; +import javax.annotation.concurrent.NotThreadSafe; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +public final class SparseOptimizerFactory { + private static final Log LOG = LogFactory.getLog(SparseOptimizerFactory.class); + + @Nonnull + public static Optimizer create(int ndims, @Nonnull Map<String, String> options) { + final String optimizerName = options.get("optimizer"); + if (optimizerName != null) { + OptimizerBase optimizerImpl; + if (optimizerName.toLowerCase().equals("sgd")) { + optimizerImpl = new Optimizer.SGD(options); + } else if (optimizerName.toLowerCase().equals("adadelta")) { + optimizerImpl = new AdaDelta(ndims, options); + } else if (optimizerName.toLowerCase().equals("adagrad")) { + optimizerImpl = new AdaGrad(ndims, options); + } else if (optimizerName.toLowerCase().equals("adam")) { + optimizerImpl = new Adam(ndims, options); + } else { + throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName); + } + + // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`. 
+ if (options.get("regularization") != null + && options.get("regularization").toLowerCase().equals("rda")) { + optimizerImpl = new AdagradRDA(ndims, optimizerImpl, options); + } + + if (LOG.isInfoEnabled()) { + LOG.info("set " + optimizerImpl.getClass().getSimpleName() + " as an optimizer: " + + options); + } + + return optimizerImpl; + } + throw new IllegalArgumentException("`optimizer` not defined"); + } + + @NotThreadSafe + static final class AdaDelta extends Optimizer.AdaDelta { + + @Nonnull + private final OpenHashMap<Object, IWeightValue> auxWeights; + + public AdaDelta(int size, Map<String, String> options) { + super(options); + this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + } + + @Override + public float update(@Nonnull Object feature, float weight, float gradient) { + IWeightValue auxWeight; + if (auxWeights.containsKey(feature)) { + auxWeight = auxWeights.get(feature); + auxWeight.set(weight); + } else { + auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeights.put(feature, auxWeight); + } + update(auxWeight, gradient); + return auxWeight.get(); + } + + } + + @NotThreadSafe + static final class AdaGrad extends Optimizer.AdaGrad { + + @Nonnull + private final OpenHashMap<Object, IWeightValue> auxWeights; + + public AdaGrad(int size, Map<String, String> options) { + super(options); + this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + } + + @Override + public float update(@Nonnull Object feature, float weight, float gradient) { + IWeightValue auxWeight; + if (auxWeights.containsKey(feature)) { + auxWeight = auxWeights.get(feature); + auxWeight.set(weight); + } else { + auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeights.put(feature, auxWeight); + } + update(auxWeight, gradient); + return auxWeight.get(); + } + + } + + @NotThreadSafe + static final class Adam extends Optimizer.Adam { + + @Nonnull + private final OpenHashMap<Object, IWeightValue> auxWeights; + + public 
Adam(int size, Map<String, String> options) { + super(options); + this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + } + + @Override + public float update(@Nonnull Object feature, float weight, float gradient) { + IWeightValue auxWeight; + if (auxWeights.containsKey(feature)) { + auxWeight = auxWeights.get(feature); + auxWeight.set(weight); + } else { + auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeights.put(feature, auxWeight); + } + update(auxWeight, gradient); + return auxWeight.get(); + } + + } + + @NotThreadSafe + static final class AdagradRDA extends Optimizer.AdagradRDA { + + @Nonnull + private final OpenHashMap<Object, IWeightValue> auxWeights; + + public AdagradRDA(int size, OptimizerBase optimizerImpl, Map<String, String> options) { + super(optimizerImpl, options); + this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + } + + @Override + public float update(@Nonnull Object feature, float weight, float gradient) { + IWeightValue auxWeight; + if (auxWeights.containsKey(feature)) { + auxWeight = auxWeights.get(feature); + auxWeight.set(weight); + } else { + auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeights.put(feature, auxWeight); + } + update(auxWeight, gradient); + return auxWeight.get(); + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java index ac17e8b..3c40c8f 100644 --- a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java +++ b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java @@ -18,12 +18,12 @@ */ package hivemall.regression; -import hivemall.common.LossFunctions; import hivemall.common.OnlineVariance; import hivemall.model.FeatureValue; 
import hivemall.model.IWeightValue; import hivemall.model.PredictionResult; import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.optimizer.LossFunctions; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java index 02473d9..68cd35c 100644 --- a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java +++ b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java @@ -18,10 +18,10 @@ */ package hivemall.regression; -import hivemall.common.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; import hivemall.model.WeightValue.WeightValueParamsF2; +import hivemall.optimizer.LossFunctions; import hivemall.utils.lang.Primitives; import javax.annotation.Nonnull; @@ -36,7 +36,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; /** * ADADELTA: AN ADAPTIVE LEARNING RATE METHOD. 
+ * + * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead */ +@Deprecated @Description( name = "train_adadelta_regr", value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])" http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/regression/AdaGradUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/AdaGradUDTF.java b/core/src/main/java/hivemall/regression/AdaGradUDTF.java index 01aec81..237566c 100644 --- a/core/src/main/java/hivemall/regression/AdaGradUDTF.java +++ b/core/src/main/java/hivemall/regression/AdaGradUDTF.java @@ -18,10 +18,10 @@ */ package hivemall.regression; -import hivemall.common.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; import hivemall.model.WeightValue.WeightValueParamsF1; +import hivemall.optimizer.LossFunctions; import hivemall.utils.lang.Primitives; import javax.annotation.Nonnull; @@ -36,7 +36,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; /** * ADAGRAD algorithm with element-wise adaptive learning rates. 
+ * + * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead */ +@Deprecated @Description( name = "train_adagrad_regr", value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])" http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java new file mode 100644 index 0000000..5137dd3 --- /dev/null +++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.regression; + +import hivemall.GeneralLearnerBaseUDTF; +import hivemall.annotations.Since; +import hivemall.model.FeatureValue; +import hivemall.optimizer.LossFunctions.LossFunction; +import hivemall.optimizer.LossFunctions.LossType; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; + +/** + * A general regression class with replaceable optimization functions. + */ +@Description(name = "train_regression", + value = "_FUNC_(list<string|int|bigint> features, double label [, const string options])" + + " - Returns a relation consists of <string|int|bigint feature, float weight>", + extended = "Build a prediction model by a generic regressor") +@Since(version = "0.5-rc.1") +public final class GeneralRegressionUDTF extends GeneralLearnerBaseUDTF { + + @Override + protected String getLossOptionDescription() { + return "Loss function [default: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, HuberLoss]"; + } + + @Override + protected LossType getDefaultLossType() { + return LossType.SquaredLoss; + } + + @Override + protected void checkLossFunction(@Nonnull LossFunction lossFunction) throws UDFArgumentException { + if (lossFunction.forBinaryClassification()) { + throw new UDFArgumentException("The loss function `" + lossFunction.getType() + + "` is not designed for regression"); + } + } + + @Override + protected void checkTargetValue(float label) throws UDFArgumentException {} + + @Override + protected void train(@Nonnull final FeatureValue[] features, final float target) { + float p = predict(features); + update(features, target, p); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/regression/LogressUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/LogressUDTF.java 
b/core/src/main/java/hivemall/regression/LogressUDTF.java index 956e8d9..c5c5bae 100644 --- a/core/src/main/java/hivemall/regression/LogressUDTF.java +++ b/core/src/main/java/hivemall/regression/LogressUDTF.java @@ -18,8 +18,8 @@ */ package hivemall.regression; -import hivemall.common.EtaEstimator; -import hivemall.common.LossFunctions; +import hivemall.optimizer.EtaEstimator; +import hivemall.optimizer.LossFunctions; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; @@ -28,6 +28,12 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +/** + * Logistic regression using SGD. + * + * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead + */ +@Deprecated @Description( name = "logress", value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])" @@ -50,7 +56,8 @@ public final class LogressUDTF extends RegressionBaseUDTF { @Override protected Options getOptions() { Options opts = super.getOptions(); - opts.addOption("t", "total_steps", true, "a total of n_samples * epochs time steps"); + opts.addOption("t", "total_steps", true, + "a total of n_samples * epochs time steps [default: 10000]"); opts.addOption("power_t", true, "The exponent for inverse scaling learning rate [default 0.1]"); opts.addOption("eta0", true, "The initial learning rate [default 0.1]"); @@ -73,7 +80,7 @@ public final class LogressUDTF extends RegressionBaseUDTF { } @Override - protected float computeUpdate(final float target, final float predicted) { + protected float computeGradient(final float target, final float predicted) { float eta = etaEstimator.eta(count); float gradient = LossFunctions.logisticLoss(target, predicted); return eta * gradient; 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java index a293d51..946a671 100644 --- a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java +++ b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java @@ -18,10 +18,10 @@ */ package hivemall.regression; -import hivemall.common.LossFunctions; import hivemall.common.OnlineVariance; import hivemall.model.FeatureValue; import hivemall.model.PredictionResult; +import hivemall.optimizer.LossFunctions; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java index a2ef4f7..eca4cf3 100644 --- a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java +++ b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java @@ -52,8 +52,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.io.FloatWritable; /** - * The base class for regression algorithms. RegressionBaseUDTF provides general implementation for - * online training and batch training. + * The base class for regression algorithms. RegressionBaseUDTF provides general implementation for online training and batch training. 
*/ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { private static final Log logger = LogFactory.getLog(RegressionBaseUDTF.class); @@ -70,6 +69,14 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { protected transient Map<Object, FloatAccumulator> accumulated; protected int sampled; + public RegressionBaseUDTF() { + this(false); + } + + public RegressionBaseUDTF(boolean enableNewModel) { + super(enableNewModel); + } + @Override public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length < 2) { @@ -235,7 +242,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { protected void update(@Nonnull final FeatureValue[] features, final float target, final float predicted) { - final float grad = computeUpdate(target, predicted); + final float grad = computeGradient(target, predicted); if (is_mini_batch) { accumulateUpdate(features, grad); @@ -247,15 +254,14 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { } } - protected float computeUpdate(float target, float predicted) { - throw new IllegalStateException(); - } - - protected IWeightValue getNewWeight(IWeightValue old_w, float delta) { - throw new IllegalStateException(); + /** + * Compute a gradient by using a loss function in derived classes + */ + protected float computeGradient(float target, float predicted) { + throw new UnsupportedOperationException(); } - protected final void accumulateUpdate(@Nonnull final FeatureValue[] features, final float coeff) { + protected void accumulateUpdate(@Nonnull final FeatureValue[] features, final float coeff) { for (int i = 0; i < features.length; i++) { if (features[i] == null) { continue; @@ -275,7 +281,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { sampled++; } - protected final void batchUpdate() { + protected void batchUpdate() { if (accumulated.isEmpty()) { this.sampled = 0; return; 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java new file mode 100644 index 0000000..e558e67 --- /dev/null +++ b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.classifier; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.text.ParseException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.StringTokenizer; +import java.util.zip.GZIPInputStream; + +import hivemall.utils.math.MathUtils; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +import org.junit.Assert; +import org.junit.Test; + +import javax.annotation.Nonnull; + +public class GeneralClassifierUDTFTest { + private static final boolean DEBUG = false; + + @Test(expected = UDFArgumentException.class) + public void testUnsupportedOptimizer() throws Exception { + GeneralClassifierUDTF udtf = new GeneralClassifierUDTF(); + ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-opt UnsupportedOpt"); + + udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params}); + } + + @Test(expected = UDFArgumentException.class) + public void testUnsupportedLossFunction() throws Exception { + GeneralClassifierUDTF udtf = new GeneralClassifierUDTF(); + ObjectInspector intOI = 
PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-loss UnsupportedLoss"); + + udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params}); + } + + @Test(expected = UDFArgumentException.class) + public void testUnsupportedRegularization() throws Exception { + GeneralClassifierUDTF udtf = new GeneralClassifierUDTF(); + ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-reg UnsupportedReg"); + + udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params}); + } + + private void run(@Nonnull String options) throws Exception { + println(options); + + ArrayList<List<String>> samplesList = new ArrayList<List<String>>(); + samplesList.add(Arrays.asList("1:-2", "2:-1")); + samplesList.add(Arrays.asList("1:-1", "2:-1")); + samplesList.add(Arrays.asList("1:-1", "2:-2")); + samplesList.add(Arrays.asList("1:1", "2:1")); + samplesList.add(Arrays.asList("1:1", "2:2")); + samplesList.add(Arrays.asList("1:2", "2:1")); + + int[] labels = new int[] {0, 0, 0, 1, 1, 1}; + + int maxIter = 512; + + GeneralClassifierUDTF udtf = new GeneralClassifierUDTF(); + ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = 
ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, options); + + udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params}); + + float cumLossPrev = Float.MAX_VALUE; + float cumLoss = 0.f; + int it = 0; + while ((it < maxIter) && (Math.abs(cumLoss - cumLossPrev) > 1e-3f)) { + cumLossPrev = cumLoss; + udtf.resetCumulativeLoss(); + for (int i = 0, size = samplesList.size(); i < size; i++) { + udtf.process(new Object[] {samplesList.get(i), labels[i]}); + } + cumLoss = udtf.getCumulativeLoss(); + println("Iter: " + ++it + ", Cumulative loss: " + cumLoss); + } + Assert.assertTrue(cumLoss / samplesList.size() < 0.5f); + + int numTests = 0; + int numCorrect = 0; + + for (int i = 0, size = samplesList.size(); i < size; i++) { + int label = labels[i]; + + float score = udtf.predict(udtf.parseFeatures(samplesList.get(i))); + int predicted = score > 0.f ? 
1 : 0; + + println("Score: " + score + ", Predicted: " + predicted + ", Actual: " + label); + + if (predicted == label) { + ++numCorrect; + } + ++numTests; + } + + float accuracy = numCorrect / (float) numTests; + println("Accuracy: " + accuracy); + Assert.assertTrue(accuracy == 1.f); + } + + @Test + public void test() throws Exception { + String[] optimizers = new String[] {"SGD", "AdaDelta", "AdaGrad", "Adam"}; + String[] regularizations = new String[] {"NO", "L1", "L2", "ElasticNet", "RDA"}; + String[] lossFunctions = new String[] {"HingeLoss", "LogLoss", "SquaredHingeLoss", + "ModifiedHuberLoss", "SquaredLoss", "QuantileLoss", "EpsilonInsensitiveLoss", + "HuberLoss"}; + + for (String opt : optimizers) { + for (String reg : regularizations) { + if (reg == "RDA" && opt != "AdaGrad") { + continue; + } + + for (String loss : lossFunctions) { + String options = "-opt " + opt + " -reg " + reg + " -loss " + loss; + + // sparse + run(options); + + if (opt != "AdaGrad") { + options += " -mini_batch 2"; + run(options); + } + + // dense + options += " -dense"; + run(options); + } + } + } + } + + @Test + public void testNews20() throws IOException, ParseException, HiveException { + int nIter = 10; + + GeneralClassifierUDTF udtf = new GeneralClassifierUDTF(); + ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-opt SGD -loss logloss -reg L2 -lambda 0.1"); + + udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params}); + + BufferedReader news20 = readFile("news20-small.binary.gz"); + ArrayList<Integer> labels = new ArrayList<Integer>(); + ArrayList<String> words = new ArrayList<String>(); + 
ArrayList<ArrayList<String>> wordsList = new ArrayList<ArrayList<String>>(); + String line = news20.readLine(); + while (line != null) { + StringTokenizer tokens = new StringTokenizer(line, " "); + int label = Integer.parseInt(tokens.nextToken()); + while (tokens.hasMoreTokens()) { + words.add(tokens.nextToken()); + } + Assert.assertFalse(words.isEmpty()); + udtf.process(new Object[] {words, label}); + + labels.add(label); + wordsList.add((ArrayList) words.clone()); + + words.clear(); + line = news20.readLine(); + } + news20.close(); + + // perform SGD iterations + for (int it = 1; it < nIter; it++) { + for (int i = 0, size = wordsList.size(); i < size; i++) { + words = wordsList.get(i); + int label = labels.get(i); + udtf.process(new Object[] {words, label}); + } + } + + int numTests = 0; + int numCorrect = 0; + + for (int i = 0, size = wordsList.size(); i < size; i++) { + words = wordsList.get(i); + int label = labels.get(i); + + float score = udtf.predict(udtf.parseFeatures(words)); + int predicted = MathUtils.sign(score); + + println("Score: " + score + ", Predicted: " + predicted + ", Actual: " + label); + + if (predicted == label) { + ++numCorrect; + } + ++numTests; + } + + float accuracy = numCorrect / (float) numTests; + println("Accuracy: " + accuracy); + Assert.assertTrue(accuracy > 0.8f); + } + + private static void println(String msg) { + if (DEBUG) { + System.out.println(msg); + } + } + + @Nonnull + private static BufferedReader readFile(@Nonnull String fileName) throws IOException { + InputStream is = GeneralClassifierUDTFTest.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java ---------------------------------------------------------------------- diff --git 
a/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java b/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java new file mode 100644 index 0000000..c892071 --- /dev/null +++ b/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.model; + +import static org.junit.Assert.assertEquals; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.lang.HalfFloat; + +import java.util.Random; + +import org.junit.Test; + +public class NewSpaceEfficientNewDenseModelTest { + + @Test + public void testGetSet() { + final int size = 1 << 12; + + final NewSpaceEfficientDenseModel model1 = new NewSpaceEfficientDenseModel(size); + //model1.configureClock(); + final NewDenseModel model2 = new NewDenseModel(size); + //model2.configureClock(); + + final Random rand = new Random(); + for (int t = 0; t < 1000; t++) { + int i = rand.nextInt(size); + float f = HalfFloat.MAX_FLOAT * rand.nextFloat(); + IWeightValue w = new WeightValue(f); + model1.set(i, w); + model2.set(i, w); + } + + assertEquals(model2.size(), model1.size()); + + IMapIterator<Integer, IWeightValue> itor = model1.entries(); + while (itor.next() != -1) { + int k = itor.getKey(); + float expected = itor.getValue().get(); + float actual = model2.getWeight(k); + assertEquals(expected, actual, 32f); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java b/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java deleted file mode 100644 index 8106890..0000000 --- a/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.model; - -import static org.junit.Assert.assertEquals; -import hivemall.utils.collections.IMapIterator; -import hivemall.utils.lang.HalfFloat; - -import java.util.Random; - -import org.junit.Test; - -public class SpaceEfficientDenseModelTest { - - @Test - public void testGetSet() { - final int size = 1 << 12; - - final SpaceEfficientDenseModel model1 = new SpaceEfficientDenseModel(size); - //model1.configureClock(); - final DenseModel model2 = new DenseModel(size); - //model2.configureClock(); - - final Random rand = new Random(); - for (int t = 0; t < 1000; t++) { - int i = rand.nextInt(size); - float f = HalfFloat.MAX_FLOAT * rand.nextFloat(); - IWeightValue w = new WeightValue(f); - model1.set(i, w); - model2.set(i, w); - } - - assertEquals(model2.size(), model1.size()); - - IMapIterator<Integer, IWeightValue> itor = model1.entries(); - while (itor.next() != -1) { - int k = itor.getKey(); - float expected = itor.getValue().get(); - float actual = model2.getWeight(k); - assertEquals(expected, actual, 32f); - } - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/test/java/hivemall/optimizer/OptimizerTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/optimizer/OptimizerTest.java b/core/src/test/java/hivemall/optimizer/OptimizerTest.java new file mode 100644 index 0000000..f54effd --- /dev/null +++ b/core/src/test/java/hivemall/optimizer/OptimizerTest.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.optimizer; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +public final class OptimizerTest { + + @Test + public void testIllegalOptimizer() { + try { + final Map<String, String> emptyOptions = new HashMap<String, String>(); + DenseOptimizerFactory.create(1024, emptyOptions); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + try { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "illegal"); + DenseOptimizerFactory.create(1024, options); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + try { + final Map<String, String> emptyOptions = new HashMap<String, String>(); + SparseOptimizerFactory.create(1024, emptyOptions); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + try { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "illegal"); + SparseOptimizerFactory.create(1024, options); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + } + + @Test + public void testOptimizerFactory() { + 
final Map<String, String> options = new HashMap<String, String>(); + final String[] regTypes = new String[] {"NO", "L1", "L2", "ElasticNet"}; + options.put("optimizer", "SGD"); + for (final String regType : regTypes) { + options.put("regularization", regType); + Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof Optimizer.SGD); + Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof Optimizer.SGD); + } + options.put("optimizer", "AdaDelta"); + for (final String regType : regTypes) { + options.put("regularization", regType); + Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.AdaDelta); + Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.AdaDelta); + } + options.put("optimizer", "AdaGrad"); + for (final String regType : regTypes) { + options.put("regularization", regType); + Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.AdaGrad); + Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.AdaGrad); + } + options.put("optimizer", "Adam"); + for (final String regType : regTypes) { + options.put("regularization", regType); + Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.Adam); + Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.Adam); + } + + // We need special handling for `Optimizer#RDA` + options.put("optimizer", "AdaGrad"); + options.put("regularization", "RDA"); + Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.AdagradRDA); + Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.AdagradRDA); + + // `SGD`, `AdaDelta`, and `Adam` currently does not support `RDA` + for (final String optimizerType : new String[] {"SGD", "AdaDelta", "Adam"}) { + options.put("optimizer", optimizerType); + try { + 
DenseOptimizerFactory.create(8, options); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + try { + SparseOptimizerFactory.create(8, options); + Assert.fail(); + } catch (IllegalArgumentException e) { + // tests passed + } + } + } + + private void testUpdateWeights(Optimizer optimizer, int numUpdates, int initSize) { + final float[] weights = new float[initSize * 2]; + final Random rnd = new Random(); + try { + for (int i = 0; i < numUpdates; i++) { + int index = rnd.nextInt(initSize); + weights[index] = optimizer.update(index, weights[index], 0.1f); + } + for (int i = 0; i < numUpdates; i++) { + int index = rnd.nextInt(initSize * 2); + weights[index] = optimizer.update(index, weights[index], 0.1f); + } + } catch (Exception e) { + Assert.fail("failed to update weights: " + e.getMessage()); + } + } + + private void testOptimizer(final Map<String, String> options, int numUpdates, int initSize) { + final Map<String, String> testOptions = new HashMap<String, String>(options); + final String[] regTypes = new String[] {"NO", "L1", "L2", "RDA", "ElasticNet"}; + for (final String regType : regTypes) { + options.put("regularization", regType); + testUpdateWeights(DenseOptimizerFactory.create(1024, testOptions), 65536, 1024); + testUpdateWeights(SparseOptimizerFactory.create(1024, testOptions), 65536, 1024); + } + } + + @Test + public void testSGDOptimizer() { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "SGD"); + testOptimizer(options, 65536, 1024); + } + + @Test + public void testAdaDeltaOptimizer() { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "AdaDelta"); + testOptimizer(options, 65536, 1024); + } + + @Test + public void testAdaGradOptimizer() { + final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "AdaGrad"); + testOptimizer(options, 65536, 1024); + } + + @Test + public void testAdamOptimizer() { 
+ final Map<String, String> options = new HashMap<String, String>(); + options.put("optimizer", "Adam"); + testOptimizer(options, 65536, 1024); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java new file mode 100644 index 0000000..15dcc22 --- /dev/null +++ b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.regression; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +import org.junit.Assert; +import org.junit.Test; + +import javax.annotation.Nonnull; + +public class GeneralRegressionUDTFTest { + private static final boolean DEBUG = false; + + @Test(expected = UDFArgumentException.class) + public void testUnsupportedOptimizer() throws Exception { + GeneralRegressionUDTF udtf = new GeneralRegressionUDTF(); + ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-opt UnsupportedOpt"); + + udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params}); + } + + @Test(expected = UDFArgumentException.class) + public void testUnsupportedLossFunction() throws Exception { + GeneralRegressionUDTF udtf = new GeneralRegressionUDTF(); + ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + 
PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-loss UnsupportedLoss"); + + udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params}); + } + + @Test(expected = UDFArgumentException.class) + public void testInvalidLossFunction() throws Exception { + GeneralRegressionUDTF udtf = new GeneralRegressionUDTF(); + ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-loss HingeLoss"); + + udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params}); + } + + @Test(expected = UDFArgumentException.class) + public void testUnsupportedRegularization() throws Exception { + GeneralRegressionUDTF udtf = new GeneralRegressionUDTF(); + ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-reg UnsupportedReg"); + + udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params}); + } + + private void run(@Nonnull String options) throws Exception { + println(options); + + int numSamples = 100; + + float x1Min = -5.f, x1Max = 5.f; + float x1Step = (x1Max - x1Min) / numSamples; + + float x2Min = -3.f, x2Max = 3.f; + float x2Step = (x2Max - x2Min) / numSamples; + + ArrayList<List<String>> samplesList = new ArrayList<List<String>>(numSamples); + ArrayList<Float> ys = new ArrayList<Float>(numSamples); + float x1 = x1Min, x2 = x2Min; + + for 
(int i = 0; i < numSamples; i++) { + samplesList.add(Arrays.asList("1:" + String.valueOf(x1), "2:" + String.valueOf(x2))); + + ys.add(x1 * 0.5f); + + x1 += x1Step; + x2 += x2Step; + } + + int numTrain = (int) (numSamples * 0.8); + int maxIter = 512; + + GeneralRegressionUDTF udtf = new GeneralRegressionUDTF(); + ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, options); + + udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params}); + + float cumLossPrev = Float.MAX_VALUE; + float cumLoss = 0.f; + int it = 0; + while ((it < maxIter) && (Math.abs(cumLoss - cumLossPrev) > 1e-3f)) { + cumLossPrev = cumLoss; + udtf.resetCumulativeLoss(); + for (int i = 0; i < numTrain; i++) { + udtf.process(new Object[] {samplesList.get(i), (Float) ys.get(i)}); + } + cumLoss = udtf.getCumulativeLoss(); + println("Iter: " + ++it + ", Cumulative loss: " + cumLoss); + } + Assert.assertTrue(cumLoss / numTrain < 0.1f); + + float accum = 0.f; + + for (int i = numTrain; i < numSamples; i++) { + float y = ys.get(i).floatValue(); + + float predicted = udtf.predict(udtf.parseFeatures(samplesList.get(i))); + println("Predicted: " + predicted + ", Actual: " + y); + + accum += Math.abs(y - predicted); + } + + float err = accum / (numSamples - numTrain); + println("Mean absolute error: " + err); + Assert.assertTrue(err < 0.2f); + } + + @Test + public void test() throws Exception { + String[] optimizers = new String[] {"SGD", "AdaDelta", "AdaGrad", "Adam"}; + String[] regularizations = new String[] {"NO", "L1", "L2", "ElasticNet", "RDA"}; + String[] lossFunctions = new String[] {"SquaredLoss", "QuantileLoss", + 
"EpsilonInsensitiveLoss", "HuberLoss"}; + + for (String opt : optimizers) { + for (String reg : regularizations) { + if (reg == "RDA" && opt != "AdaGrad") { + continue; + } + + for (String loss : lossFunctions) { + String options = "-opt " + opt + " -reg " + reg + " -loss " + loss + + " -lambda 1e-6 -eta0 1e-1"; + + // sparse + run(options); + + // mini-batch + if (opt != "AdaGrad") { + options += " -mini_batch 10"; + run(options); + } + + // dense + options += " -dense"; + run(options); + } + } + } + } + + private static void println(String msg) { + if (DEBUG) { + System.out.println(msg); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/docs/gitbook/SUMMARY.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md index 809a548..a4dccda 100644 --- a/docs/gitbook/SUMMARY.md +++ b/docs/gitbook/SUMMARY.md @@ -71,37 +71,44 @@ * [Data Generation](eval/datagen.md) * [Logistic Regression data generation](eval/lr_datagen.md) + +## Part V - Prediction -## Part V - Binary classification +* [How Prediction Works](misc/prediction.md) +* [Regression](regression/general.md) +* [Binary Classification](binaryclass/general.md) + +## Part VI - Binary classification tutorials -* [a9a Tutorial](binaryclass/a9a.md) +* [a9a](binaryclass/a9a.md) * [Data preparation](binaryclass/a9a_dataset.md) * [Logistic Regression](binaryclass/a9a_lr.md) * [Mini-batch Gradient Descent](binaryclass/a9a_minibatch.md) -* [News20 Tutorial](binaryclass/news20.md) +* [News20](binaryclass/news20.md) * [Data preparation](binaryclass/news20_dataset.md) * [Perceptron, Passive Aggressive](binaryclass/news20_pa.md) * [CW, AROW, SCW](binaryclass/news20_scw.md) * [AdaGradRDA, AdaGrad, AdaDelta](binaryclass/news20_adagrad.md) -* [KDD2010a Tutorial](binaryclass/kdd2010a.md) +* [KDD2010a](binaryclass/kdd2010a.md) * [Data preparation](binaryclass/kdd2010a_dataset.md) * [PA, CW, AROW, 
SCW](binaryclass/kdd2010a_scw.md) -* [KDD2010b Tutorial](binaryclass/kdd2010b.md) +* [KDD2010b](binaryclass/kdd2010b.md) * [Data preparation](binaryclass/kdd2010b_dataset.md) * [AROW](binaryclass/kdd2010b_arow.md) -* [Webspam Tutorial](binaryclass/webspam.md) +* [Webspam](binaryclass/webspam.md) * [Data pareparation](binaryclass/webspam_dataset.md) * [PA1, AROW, SCW](binaryclass/webspam_scw.md) -* [Kaggle Titanic Tutorial](binaryclass/titanic_rf.md) +* [Kaggle Titanic](binaryclass/titanic_rf.md) + -## Part VI - Multiclass classification +## Part VII - Multiclass classification tutorials -* [News20 Multiclass Tutorial](multiclass/news20.md) +* [News20 Multiclass](multiclass/news20.md) * [Data preparation](multiclass/news20_dataset.md) * [Data preparation for one-vs-the-rest classifiers](multiclass/news20_one-vs-the-rest_dataset.md) * [PA](multiclass/news20_pa.md) @@ -109,24 +116,24 @@ * [Ensemble learning](multiclass/news20_ensemble.md) * [one-vs-the-rest classifier](multiclass/news20_one-vs-the-rest.md) -* [Iris Tutorial](multiclass/iris.md) +* [Iris](multiclass/iris.md) * [Data preparation](multiclass/iris_dataset.md) * [SCW](multiclass/iris_scw.md) * [RandomForest](multiclass/iris_randomforest.md) -## Part VII - Regression +## Part VIII - Regression tutorials -* [E2006-tfidf regression Tutorial](regression/e2006.md) +* [E2006-tfidf regression](regression/e2006.md) * [Data preparation](regression/e2006_dataset.md) * [Passive Aggressive, AROW](regression/e2006_arow.md) -* [KDDCup 2012 track 2 CTR prediction Tutorial](regression/kddcup12tr2.md) +* [KDDCup 2012 track 2 CTR prediction](regression/kddcup12tr2.md) * [Data preparation](regression/kddcup12tr2_dataset.md) * [Logistic Regression, Passive Aggressive](regression/kddcup12tr2_lr.md) * [Logistic Regression with Amplifier](regression/kddcup12tr2_lr_amplify.md) * [AdaGrad, AdaDelta](regression/kddcup12tr2_adagrad.md) -## Part VIII - Recommendation +## Part IX - Recommendation * [Collaborative 
Filtering](recommend/cf.md) * [Item-based Collaborative Filtering](recommend/item_based_cf.md) @@ -143,22 +150,22 @@ * [Factorization Machine](recommend/movielens_fm.md) * [10-fold Cross Validation (Matrix Factorization)](recommend/movielens_cv.md) -## Part IX - Anomaly Detection +## Part X - Anomaly Detection * [Outlier Detection using Local Outlier Factor (LOF)](anomaly/lof.md) * [Change-Point Detection using Singular Spectrum Transformation (SST)](anomaly/sst.md) * [ChangeFinder: Detecting Outlier and Change-Point Simultaneously](anomaly/changefinder.md) -## Part X - Clustering +## Part XI - Clustering * [Latent Dirichlet Allocation](clustering/lda.md) * [Probabilistic Latent Semantic Analysis](clustering/plsa.md) -## Part XI - GeoSpatial functions +## Part XII - GeoSpatial functions * [Lat/Lon functions](geospatial/latlon.md) -## Part XII - Hivemall on Spark +## Part XIII - Hivemall on Spark * [Getting Started](spark/getting_started/README.md) * [Installation](spark/getting_started/installation.md) @@ -173,7 +180,7 @@ * [Top-k Join processing](spark/misc/topk_join.md) * [Other utility functions](spark/misc/functions.md) -## Part XIII - Hivemall on Docker +## Part XIV - Hivemall on Docker * [Getting Started](docker/getting_started.md) http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/docs/gitbook/binaryclass/general.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/binaryclass/general.md b/docs/gitbook/binaryclass/general.md new file mode 100644 index 0000000..50ea688 --- /dev/null +++ b/docs/gitbook/binaryclass/general.md @@ -0,0 +1,132 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +Hivemall has a generic function for classification: `train_classifier`. Compared to the other functions we will see in the later chapters, `train_classifier` provides a simpler and configurable generic interface which can be utilized to build binary classification models in a variety of settings. + +Here, we briefly introduce usage of the function. Before trying sample queries, you first need to prepare [a9a data](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#a9a). See [our a9a tutorial page](a9a_dataset.md) for further instructions. + +<!-- toc --> + +> #### Note
> This feature is supported from Hivemall v0.5-rc.1 or later. 
+ +# Preparation + +- Set `total_steps` ideally be `count(1) / {# of map tasks}`: + ``` + hive> select count(1) from a9a_train; + hive> set hivevar:total_steps=32561; + ``` +- Set `n_samples` to compute accuracy of prediction: + ``` + hive> select count(1) from a9a_test; + hive> set hivevar:n_samples=16281; + ``` + +# Training + +```sql +create table classification_model as +select + feature, + avg(weight) as weight +from + ( + select + train_classifier(add_bias(features), label, '-loss logloss -opt SGD -reg no -eta simple -total_steps ${total_steps}') as (feature, weight) + from + a9a_train + ) t +group by feature; +``` + +# Prediction & evaluation + +```sql +WITH test_exploded as ( + select + rowid, + label, + extract_feature(feature) as feature, + extract_weight(feature) as value + from + a9a_test LATERAL VIEW explode(add_bias(features)) t AS feature +), +predict as ( + select + t.rowid, + sigmoid(sum(m.weight * t.value)) as prob, + CAST((case when sigmoid(sum(m.weight * t.value)) >= 0.5 then 1.0 else 0.0 end) as FLOAT) as label + from + test_exploded t LEFT OUTER JOIN + classification_model m ON (t.feature = m.feature) + group by + t.rowid +), +submit as ( + select + t.label as actual, + pd.label as predicted, + pd.prob as probability + from + a9a_test t JOIN predict pd + on (t.rowid = pd.rowid) +) +select count(1) / ${n_samples} from submit +where actual = predicted; +``` + +# Comparison with the other binary classifiers + +In the next part of this user guide, our binary classification tutorials introduce many different functions: + +- [Logistic Regression](a9a_lr.md) + - and [its mini-batch variant](a9a_minibatch.md) +- [Perceptron](news20_pa.md#perceptron) +- [Passive Aggressive](news20_pa.md#passive-aggressive) +- [CW](news20_scw.md#confidence-weighted-cw) +- [AROW](news20_scw.md#adaptive-regularization-of-weight-vectors-arow) +- [SCW](news20_scw.md#soft-confidence-weighted-scw1) +- [AdaGradRDA](news20_adagrad.md#adagradrda) +- 
[AdaGrad](news20_adagrad.md#adagrad) +- [AdaDelta](news20_adagrad.md#adadelta) + +All of them actually have the same interface, but mathematical formulation and its implementation differ from each other. + +In particular, the above sample queries are almost same as [a9a tutorial using Logistic Regression](a9a_lr.md). The difference is only in a choice of training function: `logress()` vs. `train_classifier()`. + +However, at the same time, the options `-loss logloss -opt SGD -reg no -eta simple -total_steps ${total_steps}` for `train_classifier` indicates that Hivemall uses the generic classifier as Logistic Regressor (`logress`). Hence, the accuracy of prediction based on either `logress` and `train_classifier` should be same under the configuration. + +In addition, `train_classifier` supports the `-mini_batch` option in a similar manner to [what `logress` does](a9a_minibatch.md). Thus, following two training queries show the same results: + +```sql +select + logress(add_bias(features), label, '-total_steps ${total_steps} -mini_batch 10') as (feature, weight) +from + a9a_train +``` + +```sql +select + train_classifier(add_bias(features), label, '-loss logloss -opt SGD -reg no -eta simple -total_steps ${total_steps} -mini_batch 10') as (feature, weight) +from + a9a_train +``` + +Likewise, you can generate many different classifiers based on its options. \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/docs/gitbook/misc/prediction.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/misc/prediction.md b/docs/gitbook/misc/prediction.md new file mode 100644 index 0000000..53fe03f --- /dev/null +++ b/docs/gitbook/misc/prediction.md @@ -0,0 +1,137 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +<!-- toc --> + +# What is "prediction problem"? + +In a context of machine learning, numerous tasks can be seen as **prediction problem**. For example, this user guide provides solutions for: + +- [spam detection](../binaryclass/webspam.md) +- [news article classification](../multiclass/news20.md) +- [click-through-rate estimation](../regression/kddcup12tr2.md) + +For any kinds of prediction problems, we generally provide a set of input-output pairs as: + +- **Input:** Set of features + - e.g., `["1:0.001","4:0.23","35:0.0035",...]` +- **Output:** Target value + - e.g., 1, 0, 0.54, 42.195, ... + +Once a prediction model has been constructed based on the samples, the model can make prediction for unforeseen inputs. + +In order to train prediction models, an algorithm so-called ***stochastic gradient descent*** (SGD) is normally applied. You can learn more about this from the following external resources: + +- [scikit-learn documentation](http://scikit-learn.org/stable/modules/sgd.html) +- [Spark MLlib documentation](http://spark.apache.org/docs/latest/mllib-optimization.html) + +Importantly, depending on types of output value, prediction problem can be categorized into **regression** and **classification** problem. 
+ +# Regression + +The goal of regression is to predict **real values** as shown below: + +| features (input) | target real value (output) | +|:---|:---:| +|["1:0.001","4:0.23","35:0.0035",...] | 21.3 | +|["1:0.2","3:0.1","13:0.005",...] | 6.2 | +|["5:1.3","22:0.089","77:0.0001",...] | 17.1 | +| ... | ... | + +In practice, target values could be any of small/large float/int negative/positive values. [Our CTR prediction tutorial](../regression/kddcup12tr2.md) solves regression problem with small floating point target values in a 0-1 range, for example. + +While there are several ways to realize regression by using Hivemall, `train_regression()` is one of the most flexible functions. This feature is explained in: [Regression](../regression/general.md). + +# Classification + +In contrast to regression, output for classification problems should be (integer) **labels**: + +| features (input) | label (output) | +|:---|:---:| +|["1:0.001","4:0.23","35:0.0035",...] | 0 | +|["1:0.2","3:0.1","13:0.005",...] | 1 | +|["5:1.3","22:0.089","77:0.0001",...] | 1 | +| ... | ... | + +In case the number of possible labels is 2 (0/1 or -1/1), the problem is **binary classification**, and Hivemall's `train_classifier()` function enables you to build binary classifiers. [Binary Classification](../binaryclass/general.md) demonstrates how to use the function. + +Another type of classification problems is **multi-class classification**. This task assumes that the number of possible labels is more than 2. We need to use different functions for the multi-class problems, and our [news20](../multiclass/news20.md) and [iris](../multiclass/iris.md) tutorials would be helpful. + +# Mathematical formulation of generic prediction model + +Here, we briefly explain about how prediction model is constructed. 
+ +First and foremost, we represent **input** and **output** for prediction models as follows: + +- **Input:** a vector $$\mathbf{x}$$ +- **Output:** a value $$y$$ + +For a set of samples $$(\mathbf{x}_1, y_1), (\mathbf{x}_2, y_2), \cdots, (\mathbf{x}_n, y_n)$$, the goal of prediction algorithms is to find a weight vector (i.e., parameters) $$\mathbf{w}$$ by minimizing the following error: + +$$ +E(\mathbf{w}) := \frac{1}{n} \sum_{i=1}^{n} L(\mathbf{w}; \mathbf{x}_i, y_i) + \lambda R(\mathbf{w}) +$$ + +In the above formulation, there are two auxiliary functions we have to know: + +- $$L(\mathbf{w}; \mathbf{x}_i, y_i)$$ + - **Loss function** for a single sample $$(\mathbf{x}_i, y_i)$$ and given $$\mathbf{w}$$. + - If this function produces small values, it means the parameter $$\mathbf{w}$$ is successfully learnt. +- $$R(\mathbf{w})$$ + - **Regularization function** for the current parameter $$\mathbf{w}$$. + - It prevents the model from falling into an undesirable condition so-called **over-fitting**. + +($$\lambda$$ is a small value which controls the effect of the regularization function.) + +Eventually, minimizing the function $$E(\mathbf{w})$$ can be implemented by the SGD technique as described before, and $$\mathbf{w}$$ itself is used as a "model" for future prediction. + +Interestingly, depending on the choice of loss and regularization function, the prediction model you obtain will behave differently; even if one combination could work as a classifier, another choice might be appropriate for regression. 
+ +Below we list possible options for `train_regression` and `train_classifier`, and this is the reason why these two functions are the most flexible in Hivemall: + +- Loss function: `-loss`, `-loss_function` + - For `train_regression` + - SquaredLoss + - QuantileLoss + - EpsilonInsensitiveLoss + - HuberLoss + - For `train_classifier` + - HingeLoss + - LogLoss + - SquaredHingeLoss + - ModifiedHuberLoss + - SquaredLoss + - QuantileLoss + - EpsilonInsensitiveLoss + - HuberLoss +- Regularization function: `-reg`, `-regularization` + - L1 + - L2 + - ElasticNet + - RDA + +Additionally, there are several variants of the SGD technique, and it is also configurable as: + +- Optimizer: `-opt`, `-optimizer` + - SGD + - AdaGrad + - AdaDelta + - Adam + +In practice, you can try different combinations of the options in order to achieve higher prediction accuracy. \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/docs/gitbook/regression/general.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/regression/general.md b/docs/gitbook/regression/general.md new file mode 100644 index 0000000..dee0719 --- /dev/null +++ b/docs/gitbook/regression/general.md @@ -0,0 +1,74 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. 
See the License for the + specific language governing permissions and limitations + under the License. +--> + +In our regression tutorials, you can tackle realistic prediction problems by using several of Hivemall's regression features such as: + +- [PA1a](e2006_arow.html#pa1a) +- [PA2a](e2006_arow.html#pa2a) +- [AROW](e2006_arow.html#arow) +- [AROWe](e2006_arow.html#arowe) + +Our `train_regression` function enables you to solve regression problems with flexible, configurable options. Let us try the function below. + +It should be noted that the sample queries require you to prepare [E2006-tfidf data](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression.html#E2006-tfidf). See [our E2006-tfidf tutorial page](../regression/e2006_dataset.md) for further instructions. + +<!-- toc --> + +> #### Note
> This feature is supported from Hivemall v0.5-rc.1 or later. + +# Training + +```sql +create table e2006tfidf_regression_model as +select + feature, + avg(weight) as weight +from ( + select + train_regression(features,target,'-loss squaredloss -opt AdaGrad -reg no') as (feature,weight) + from + e2006tfidf_train_x3 +) t +group by feature; +``` + +# Prediction & evaluation + +```sql +WITH predict as ( + select + t.rowid, + sum(m.weight * t.value) as predicted + from + e2006tfidf_test_exploded t LEFT OUTER JOIN + e2006tfidf_regression_model m ON (t.feature = m.feature) + group by + t.rowid +), +submit as ( + select + t.target as actual, + p.predicted as predicted + from + e2006tfidf_test t JOIN predict p + on (t.rowid = p.rowid) +) +select rmse(predicted, actual) from submit; +``` http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java ---------------------------------------------------------------------- diff --git a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java index 3a65da8..2b475c1 100644 --- 
a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java +++ b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java @@ -18,9 +18,9 @@ */ package hivemall.mix.server; -import hivemall.model.DenseModel; +import hivemall.model.NewDenseModel; import hivemall.model.PredictionModel; -import hivemall.model.SparseModel; +import hivemall.model.NewSparseModel; import hivemall.model.WeightValue; import hivemall.mix.MixMessage.MixEventName; import hivemall.mix.client.MixClient; @@ -55,7 +55,7 @@ public class MixServerTest extends HivemallTestBase { waitForState(server, ServerState.RUNNING); - PredictionModel model = new DenseModel(16777216, false); + PredictionModel model = new NewDenseModel(16777216); model.configureClock(); MixClient client = null; try { @@ -93,7 +93,7 @@ public class MixServerTest extends HivemallTestBase { waitForState(server, ServerState.RUNNING); - PredictionModel model = new DenseModel(16777216, false); + PredictionModel model = new NewDenseModel(16777216); model.configureClock(); MixClient client = null; try { @@ -151,7 +151,7 @@ public class MixServerTest extends HivemallTestBase { } private static void invokeClient(String groupId, int serverPort) throws InterruptedException { - PredictionModel model = new DenseModel(16777216, false); + PredictionModel model = new NewDenseModel(16777216); model.configureClock(); MixClient client = null; try { @@ -296,10 +296,10 @@ public class MixServerTest extends HivemallTestBase { serverExec.shutdown(); } - private static void invokeClient01(String groupId, int serverPort, boolean denseModel, - boolean cancelMix) throws InterruptedException { - PredictionModel model = denseModel ? new DenseModel(100, false) : new SparseModel(100, - false); + private static void invokeClient01(String groupId, int serverPort, boolean denseModel, boolean cancelMix) + throws InterruptedException { + PredictionModel model = denseModel ? 
new NewDenseModel(100) + : new NewSparseModel(100, false); model.configureClock(); MixClient client = null; try { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/resources/ddl/define-all-as-permanent.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index 52f3fab..0485890 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -13,6 +13,9 @@ CREATE FUNCTION hivemall_version as 'hivemall.HivemallVersionUDF' USING JAR '${h -- binary classification -- --------------------------- +DROP FUNCTION IF EXISTS train_classifier; +CREATE FUNCTION train_classifier as 'hivemall.classifier.GeneralClassifierUDTF' USING JAR '${hivemall_jar}'; + DROP FUNCTION IF EXISTS train_perceptron; CREATE FUNCTION train_perceptron as 'hivemall.classifier.PerceptronUDTF' USING JAR '${hivemall_jar}'; @@ -334,6 +337,13 @@ CREATE FUNCTION tf as 'hivemall.ftvec.text.TermFrequencyUDAF' USING JAR '${hivem -- Regression functions -- -------------------------- +DROP FUNCTION IF EXISTS train_regression; +CREATE FUNCTION train_regression as 'hivemall.regression.GeneralRegressionUDTF' USING JAR '${hivemall_jar}'; + +DROP FUNCTION IF EXISTS train_logregr; +CREATE FUNCTION train_logregr as 'hivemall.regression.LogressUDTF' USING JAR '${hivemall_jar}'; + +-- alias for backward compatibility DROP FUNCTION IF EXISTS logress; CREATE FUNCTION logress as 'hivemall.regression.LogressUDTF' USING JAR '${hivemall_jar}'; @@ -688,3 +698,4 @@ CREATE FUNCTION xgboost_predict AS 'hivemall.xgboost.tools.XGBoostPredictUDTF' U DROP FUNCTION xgboost_multiclass_predict; CREATE FUNCTION xgboost_multiclass_predict AS 'hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF' USING JAR '${hivemall_jar}'; +======= http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/resources/ddl/define-all.hive 
---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index f752c19..bd79a6b 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -2,13 +2,16 @@ -- Hivemall: Hive scalable Machine Learning Library ----------------------------------------------------------------------------- -drop temporary function if exists hivemall_version; +drop temporary function if exists hivemall_version; create temporary function hivemall_version as 'hivemall.HivemallVersionUDF'; --------------------------- -- binary classification -- --------------------------- +drop temporary function if exists train_classifier; +create temporary function train_classifier as 'hivemall.classifier.GeneralClassifierUDTF'; + drop temporary function if exists train_perceptron; create temporary function train_perceptron as 'hivemall.classifier.PerceptronUDTF'; @@ -330,6 +333,9 @@ create temporary function tf as 'hivemall.ftvec.text.TermFrequencyUDAF'; -- Regression functions -- -------------------------- +drop temporary function if exists train_regression; +create temporary function train_regression as 'hivemall.regression.GeneralRegressionUDTF'; + drop temporary function if exists logress; create temporary function logress as 'hivemall.regression.LogressUDTF'; @@ -698,5 +704,3 @@ log(10, n_docs / max2(1,df_t)) + 1.0; create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE) tf * (log(10, n_docs / max2(1,df_t)) + 1.0); - - http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/resources/ddl/define-all.spark ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index e4c4f1a..261fb8d 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -11,6 +11,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION hivemall_version AS 'hivemall.Hivemall * Binary 
classification */ +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_classifier") +sqlContext.sql("CREATE TEMPORARY FUNCTION train_classifier AS 'hivemall.classifier.GeneralClassifierUDTF'") + sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_perceptron") sqlContext.sql("CREATE TEMPORARY FUNCTION train_perceptron AS 'hivemall.classifier.PerceptronUDTF'") @@ -333,6 +336,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION tf AS 'hivemall.ftvec.text.TermFrequen * Regression functions */ +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_regression") +sqlContext.sql("CREATE TEMPORARY FUNCTION train_regression AS 'hivemall.regression.GeneralRegressionUDTF'") + sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS logress") sqlContext.sql("CREATE TEMPORARY FUNCTION logress AS 'hivemall.regression.LogressUDTF'") http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/resources/ddl/define-udfs.td.hql ---------------------------------------------------------------------- diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql index 547bf84..7fba0d6 100644 --- a/resources/ddl/define-udfs.td.hql +++ b/resources/ddl/define-udfs.td.hql @@ -166,6 +166,8 @@ create temporary function tile as 'hivemall.geospatial.TileUDF'; create temporary function map_url as 'hivemall.geospatial.MapURLUDF'; create temporary function l2_norm as 'hivemall.tools.math.L2NormUDAF'; create temporary function dimsum_mapper as 'hivemall.knn.similarity.DIMSUMMapperUDTF'; +create temporary function train_classifier as 'hivemall.classifier.GeneralClassifierUDTF'; +create temporary function train_regression as 'hivemall.regression.GeneralRegressionUDTF'; -- NLP features create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF'; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala ---------------------------------------------------------------------- 
diff --git a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala b/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala index dbb818b..cadc852 100644 --- a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala +++ b/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala @@ -22,10 +22,12 @@ import java.util.Random import java.util.concurrent.{Executors, ExecutorService, TimeUnit} import java.util.logging.Logger -import hivemall.mix.MixMessage.MixEventName import hivemall.mix.client.MixClient +import hivemall.mix.MixMessage.MixEventName import hivemall.mix.server.MixServer.ServerState -import hivemall.model.{DenseModel, PredictionModel, WeightValue} +import hivemall.model.{DenseModel, PredictionModel} +import hivemall.model.{NewDenseModel, PredictionModel} +import hivemall.model.WeightValue import hivemall.utils.io.IOUtils import hivemall.utils.lang.CommandLineUtils import hivemall.utils.net.NetUtils @@ -96,7 +98,7 @@ class MixServerSuite extends FunSuite with BeforeAndAfter { ignore(testName) { val clients = Executors.newCachedThreadPool() val numClients = nclient - val models = (0 until numClients).map(i => new DenseModel(ndims, false)) + val models = (0 until numClients).map(i => new NewDenseModel(ndims, false)) (0 until numClients).map { i => clients.submit(new Runnable() { override def run(): Unit = {
