Close #13: Implement Kernel Expansion Passive Aggressive Classification
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/391e7f1c Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/391e7f1c Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/391e7f1c Branch: refs/heads/master Commit: 391e7f1c65c7084acfbfe6b2491765a4dcd8212a Parents: 273851f Author: Sotaro Sugimoto <sotaro.sugim...@gmail.com> Authored: Mon Jan 16 16:01:46 2017 +0900 Committer: myui <yuin...@gmail.com> Committed: Mon Jan 16 16:01:46 2017 +0900 ---------------------------------------------------------------------- .../src/main/java/hivemall/LearnerBaseUDTF.java | 4 +- .../anomaly/SingularSpectrumTransformUDF.java | 22 +- .../classifier/BinaryOnlineClassifierUDTF.java | 11 +- .../hivemall/classifier/KPAPredictUDAF.java | 228 + .../KernelExpansionPassiveAggressiveUDTF.java | 378 + .../hivemall/ensemble/bagging/VotedAvgUDAF.java | 8 +- .../ensemble/bagging/WeightVotedAvgUDAF.java | 8 +- .../ftvec/pairing/FeaturePairsUDTF.java | 232 + .../main/java/hivemall/model/FeatureValue.java | 29 +- .../utils/collections/FloatArrayList.java | 152 + .../collections/Int2FloatOpenHashTable.java | 3 + .../java/hivemall/utils/hadoop/HiveUtils.java | 51 + .../hivemall/utils/hashing/HashFunction.java | 14 + .../java/hivemall/utils/lang/Preconditions.java | 18 +- .../java/hivemall/utils/math/MathUtils.java | 18 +- .../hivemall/anomaly/ChangeFinder1DTest.java | 2 +- .../hivemall/anomaly/ChangeFinder2DTest.java | 2 +- ...ernelExpansionPassiveAggressiveUDTFTest.java | 158 + .../fm/FactorizationMachineUDTFTest.java | 16 +- .../FieldAwareFactorizationMachineUDTFTest.java | 16 +- .../mf/BPRMatrixFactorizationUDTFTest.java | 16 +- .../collections/Int2FloatOpenHashMapTest.java | 13 + .../utils/io/Base91OutputStreamTest.java | 2 +- .../hivemall/utils/lang/PreconditionsTest.java | 2 +- .../test/resources/hivemall/anomaly/cf1d.csv | 2503 - .../test/resources/hivemall/anomaly/cf1d.csv.gz | Bin 0 -> 17417 bytes .../hivemall/classifier/news20-small.binary.gz | Bin 0 -> 121787 bytes core/src/test/resources/hivemall/fm/5107786.txt | 200 - .../test/resources/hivemall/fm/5107786.txt.gz | Bin 0 -> 2712 bytes .../test/resources/hivemall/fm/bigdata.tr.txt | 200 - .../resources/hivemall/fm/bigdata.tr.txt.gz | Bin 0 -> 4544 bytes core/src/test/resources/hivemall/mf/ml1k.test | 20000 ----- .../src/test/resources/hivemall/mf/ml1k.test.gz | Bin 0 -> 78568 bytes core/src/test/resources/hivemall/mf/ml1k.train | 80000 ----------------- .../test/resources/hivemall/mf/ml1k.train.gz | Bin 0 -> 311183 bytes resources/ddl/define-all-as-permanent.hive | 9 + resources/ddl/define-all.hive | 9 + 37 files changed, 1367 insertions(+), 102957 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/LearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java index f1ad99e..17c3ebc 100644 --- a/core/src/main/java/hivemall/LearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java @@ -161,11 +161,13 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { return cl; } + @Nullable protected PredictionModel createModel() { return createModel(null); } - protected PredictionModel createModel(String label) { + @Nonnull + protected PredictionModel createModel(@Nullable String label) { PredictionModel model; final boolean useCovar = useCovariance(); if (dense_model) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java index d699a95..1fac3e7 100644 --- a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java +++ b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java @@ -125,19 +125,17 @@ public final class SingularSpectrumTransformUDF extends UDFWithOptions { this._params.changepointThreshold = Primitives.parseDouble(cl.getOptionValue("th"), _params.changepointThreshold); - Preconditions.checkArgument(_params.w >= 2, "w must be greather than 1: " + _params.w, - UDFArgumentException.class); - Preconditions.checkArgument(_params.r >= 1, "r must be greater than 0: " + _params.r, - UDFArgumentException.class); - Preconditions.checkArgument(_params.k >= 1, "k must be greater than 0: " + _params.k, - UDFArgumentException.class); - Preconditions.checkArgument(_params.k >= _params.r, - "k must be equals to or greather than r: k=" + _params.k + ", r" + _params.r, - UDFArgumentException.class); + Preconditions.checkArgument(_params.w >= 2, UDFArgumentException.class, + "w must be greather than 1: " + _params.w); + Preconditions.checkArgument(_params.r >= 1, UDFArgumentException.class, + "r must be greater than 0: " + _params.r); + Preconditions.checkArgument(_params.k >= 1, UDFArgumentException.class, + "k must be greater than 0: " + _params.k); + Preconditions.checkArgument(_params.k >= _params.r, UDFArgumentException.class, + "k must be equals to or greather than r: k=" + _params.k + ", r" + _params.r); Preconditions.checkArgument(_params.changepointThreshold > 0.d - && _params.changepointThreshold < 1.d, - "changepointThreshold must be in range (0, 1): " + _params.changepointThreshold, - UDFArgumentException.class); + && _params.changepointThreshold < 1.d, UDFArgumentException.class, + "changepointThreshold must be in range (0, 1): " + _params.changepointThreshold); return cl; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java index c9274e4..b0e2efd 100644 --- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java @@ -51,8 +51,8 @@ import org.apache.hadoop.io.FloatWritable; public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { private static final Log logger = LogFactory.getLog(BinaryOnlineClassifierUDTF.class); - private ListObjectInspector featureListOI; - private PrimitiveObjectInspector labelOI; + protected ListObjectInspector featureListOI; + protected PrimitiveObjectInspector labelOI; private boolean parseFeature; protected PredictionModel model; @@ -122,7 +122,7 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { } @Nullable - protected final FeatureValue[] parseFeatures(@Nonnull final List<?> features) { + FeatureValue[] parseFeatures(@Nonnull final List<?> features) { final int size = features.size(); if (size == 0) { return null; @@ -151,6 +151,7 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { assert (label == -1 || label == 0 || label == 1) : label; } + //@VisibleForTesting void train(List<?> features, int label) { FeatureValue[] featureVector = parseFeatures(features); train(featureVector, label); @@ -166,7 +167,7 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { } } - protected float predict(@Nonnull final FeatureValue[] features) { + float predict(@Nonnull final FeatureValue[] features) { float score = 0.f; for (FeatureValue f : features) {// a += w[i] * x[i] if (f == null) { @@ -247,7 +248,7 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { } @Override - public final void close() throws HiveException { + public void close() throws HiveException { super.close(); if (model != null) { int numForwarded = 0; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/classifier/KPAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/KPAPredictUDAF.java b/core/src/main/java/hivemall/classifier/KPAPredictUDAF.java new file mode 100644 index 0000000..72409d9 --- /dev/null +++ b/core/src/main/java/hivemall/classifier/KPAPredictUDAF.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.classifier; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Preconditions; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; +import org.apache.hadoop.hive.ql.util.JavaDataModel; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; + +@Description( + name = "kpa_predict", + value = "_FUNC_(@Nonnull double xh, @Nonnull double xk, @Nullable float w0, @Nonnull float w1, @Nonnull float w2, @Nullable float w3)" + + " - Returns a prediction value in Double") +public final class KPAPredictUDAF extends AbstractGenericUDAFResolver { + + @Override + public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException { + if (parameters.length != 6) { + throw new UDFArgumentException( + "_FUNC_(double xh, double xk, float w0, float w1, float w2, float w3) takes exactly 6 arguments but got: " + + parameters.length); + } + if (!HiveUtils.isNumberTypeInfo(parameters[0])) { + throw new UDFArgumentTypeException(0, "Number type is expected for xh (1st argument): " + + parameters[0].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(parameters[1])) { + throw new UDFArgumentTypeException(1, "Number type is expected for xk (2nd argument): " + + parameters[1].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(parameters[2])) { + throw new UDFArgumentTypeException(2, "Number type is expected for w0 (3rd argument): " + + parameters[2].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(parameters[3])) { + throw new UDFArgumentTypeException(3, "Number type is expected for w1 (4th argument): " + + parameters[3].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(parameters[4])) { + throw new UDFArgumentTypeException(4, "Number type is expected for w2 (5th argument): " + + parameters[4].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(parameters[5])) { + throw new UDFArgumentTypeException(5, "Number type is expected for w3 (6th argument): " + + parameters[5].getTypeName()); + } + + return new Evaluator(); + } + + public static class Evaluator extends GenericUDAFEvaluator { + + @Nullable + private transient PrimitiveObjectInspector xhOI, xkOI; + @Nullable + private transient PrimitiveObjectInspector w0OI, w1OI, w2OI, w3OI; + + public Evaluator() {} + + @Override + public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { + super.init(m, parameters); + + // initialize input + if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {// from original data + this.xhOI = HiveUtils.asNumberOI(parameters[0]); + this.xkOI = HiveUtils.asNumberOI(parameters[1]); + this.w0OI = HiveUtils.asNumberOI(parameters[2]); + this.w1OI = HiveUtils.asNumberOI(parameters[3]); + this.w2OI = HiveUtils.asNumberOI(parameters[4]); + this.w3OI = HiveUtils.asNumberOI(parameters[5]); + } + + return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + + @Override + public AggrBuffer getNewAggregationBuffer() throws HiveException { + return new AggrBuffer(); + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + AggrBuffer aggr = (AggrBuffer) agg; + aggr.reset(); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + Preconditions.checkArgument(parameters.length == 6, HiveException.class); + + final AggrBuffer aggr = (AggrBuffer) agg; + + if (parameters[0] /* xh */!= null) { + double xh = HiveUtils.getDouble(parameters[0], xhOI); + if (parameters[1] /* xk */!= null) { + if (parameters[5] /* w3hk */== null) { + return; + } + // xh, xk, w3hk + double xk = HiveUtils.getDouble(parameters[1], xkOI); + double w3hk = HiveUtils.getDouble(parameters[5], w3OI); + aggr.addW3(xh, xk, w3hk); + } else { + if (parameters[3] /* w1h */== null) { + return; + } + // xh, w1h, w2h + Preconditions.checkNotNull(parameters[4], HiveException.class); + double w1h = HiveUtils.getDouble(parameters[3], w1OI); + double w2h = HiveUtils.getDouble(parameters[4], w2OI); + aggr.addW1W2(xh, w1h, w2h); + } + } else if (parameters[2] /* w0 */!= null) { + // w0 + double w0 = HiveUtils.getDouble(parameters[2], w0OI); + aggr.addW0(w0); + } + } + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + AggrBuffer aggr = (AggrBuffer) agg; + double v = aggr.get(); + return new DoubleWritable(v); + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + AggrBuffer aggr = (AggrBuffer) agg; + DoubleWritable other = (DoubleWritable) partial; + double v = other.get(); + aggr.merge(v); + } + + @Override + public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + AggrBuffer aggr = (AggrBuffer) agg; + double v = aggr.get(); + return new DoubleWritable(v); + } + + } + + @AggregationType(estimable = true) + static class AggrBuffer extends AbstractAggregationBuffer { + + double score; + + AggrBuffer() { + super(); + reset(); + } + + @Override + public int estimate() { + return JavaDataModel.PRIMITIVES2; + } + + void reset() { + this.score = 0.d; + } + + double get() { + return score; + } + + void addW0(@Nonnull double w0) { + this.score += w0; + } + + void addW1W2(final double xh, final double w1h, final double w2h) { + this.score += (w1h * xh + w2h * xh * xh); + } + + void addW3(final double xh, final double xk, final double w3hk) { + this.score += (w3hk * xh * xk); + } + + void merge(final double other) { + this.score += other; + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java new file mode 100644 index 0000000..3e28932 --- /dev/null +++ b/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.classifier; + +import hivemall.common.LossFunctions; +import hivemall.model.FeatureValue; +import hivemall.model.PredictionModel; +import hivemall.model.PredictionResult; +import hivemall.utils.collections.Int2FloatOpenHashTable; +import hivemall.utils.collections.Int2FloatOpenHashTable.IMapIterator; +import hivemall.utils.hashing.HashFunction; +import hivemall.utils.lang.Preconditions; + +import java.util.ArrayList; +import java.util.List; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; + +/** + * Degree-2 polynomial kernel expansion Passive Aggressive. + * + * <pre> + * Hideki Isozaki and Hideto Kazawa: Efficient Support Vector Classifiers for Named Entity Recognition, Proc.COLING, 2002 + * </pre> + */ +@Description(name = "train_kpa", + value = "_FUNC_(array<string|int|bigint> features, int label [, const string options])" + + " - returns a relation <h int, hk int, float w0, float w1, float w2, float w3>") +public final class KernelExpansionPassiveAggressiveUDTF extends BinaryOnlineClassifierUDTF { + + // ------------------------------------ + // Hyper parameters + private float _pkc; + // Algorithm + private Algorithm _algo; + + // ------------------------------------ + // Model parameters + + private float _w0; + private Int2FloatOpenHashTable _w1; + private Int2FloatOpenHashTable _w2; + private Int2FloatOpenHashTable _w3; + + // ------------------------------------ + + private float _loss; + + public KernelExpansionPassiveAggressiveUDTF() {} + + //@VisibleForTesting + float getLoss() {//only used for testing purposes at the moment + return _loss; + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("pkc", true, + "Constant c inside polynomial kernel K = (dot(xi,xj) + c)^2 [default 1.0]"); + opts.addOption("algo", "algorithm", true, + "Algorithm for calculating loss [pa, pa1 (default), pa2]"); + opts.addOption("c", "aggressiveness", true, + "Aggressiveness parameter C for PA-1 and PA-2 [default 1.0]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + float pkc = 1.f; + float c = 1.f; + String algo = "pa1"; + + final CommandLine cl = super.processOptions(argOIs); + if (cl != null) { + String pkc_str = cl.getOptionValue("pkc"); + if (pkc_str != null) { + pkc = Float.parseFloat(pkc_str); + } + String c_str = cl.getOptionValue("c"); + if (c_str != null) { + c = Float.parseFloat(c_str); + if (c <= 0.f) { + throw new UDFArgumentException("Aggressiveness parameter C must be C > 0: " + c); + } + } + algo = cl.getOptionValue("algo", algo); + } + + if ("pa1".equalsIgnoreCase(algo)) { + this._algo = new PA1(c); + } else if ("pa2".equalsIgnoreCase(algo)) { + this._algo = new PA2(c); + } else if ("pa".equalsIgnoreCase(algo)) { + this._algo = new PA(); + } else { + throw new UDFArgumentException("Unsupported algorithm: " + algo); + } + this._pkc = pkc; + + return cl; + } + + interface Algorithm { + float eta(final float loss, @Nonnull final PredictionResult margin); + } + + static class PA implements Algorithm { + + PA() {} + + @Override + public float eta(float loss, PredictionResult margin) { + return loss / margin.getSquaredNorm(); + } + } + + static class PA1 implements Algorithm { + private final float c; + + PA1(float c) { + this.c = c; + } + + @Override + public float eta(float loss, PredictionResult margin) { + float squared_norm = margin.getSquaredNorm(); + float eta = loss / squared_norm; + return Math.min(c, eta); + } + } + + static class PA2 implements Algorithm { + private final float c; + + PA2(float c) { + this.c = c; + } + + @Override + public float eta(float loss, PredictionResult margin) { + float squared_norm = margin.getSquaredNorm(); + float eta = loss / (squared_norm + (0.5f / c)); + return eta; + } + } + + @Override + protected PredictionModel createModel() { + this._w0 = 0.f; + this._w1 = new Int2FloatOpenHashTable(16384); + _w1.defaultReturnValue(0.f); + this._w2 = new Int2FloatOpenHashTable(16384); + _w2.defaultReturnValue(0.f); + this._w3 = new Int2FloatOpenHashTable(16384); + _w3.defaultReturnValue(0.f); + + return null; + } + + @Override + protected StructObjectInspector getReturnOI(ObjectInspector featureRawOI) { + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("h"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("w0"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + fieldNames.add("w1"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + fieldNames.add("w2"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + fieldNames.add("hk"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("w3"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Nullable + FeatureValue[] parseFeatures(@Nonnull final List<?> features) { + final int size = features.size(); + if (size == 0) { + return null; + } + + final FeatureValue[] featureVector = new FeatureValue[size]; + for (int i = 0; i < size; i++) { + Object f = features.get(i); + if (f == null) { + continue; + } + FeatureValue fv = FeatureValue.parse(f, true); + featureVector[i] = fv; + } + return featureVector; + } + + @Override + protected void train(@Nonnull final FeatureValue[] features, final int label) { + final float y = label > 0 ? 1.f : -1.f; + + PredictionResult margin = calcScoreWithKernelAndNorm(features); + float p = margin.getScore(); + float loss = LossFunctions.hingeLoss(p, y); // 1.0 - y * p + this._loss = loss; + + if (loss > 0.f) { // y * p < 1 + updateKernel(y, loss, margin, features); + } + } + + @Override + float predict(@Nonnull final FeatureValue[] features) { + float score = 0.f; + + for (int i = 0; i < features.length; ++i) { + if (features[i] == null) { + continue; + } + int h = features[i].getFeatureAsInt(); + float w1 = _w1.get(h); + float w2 = _w2.get(h); + double xi = features[i].getValue(); + double xx = xi * xi; + score += w1 * xi; + score += w2 * xx; + for (int j = i + 1; j < features.length; ++j) { + int k = features[j].getFeatureAsInt(); + int hk = HashFunction.hash(h, k, true); + float w3 = _w3.get(hk); + double xj = features[j].getValue(); + score += xi * xj * w3; + } + } + + return score; + } + + @Nonnull + final PredictionResult calcScoreWithKernelAndNorm(@Nonnull final FeatureValue[] features) { + float score = _w0; + float norm = 0.f; + for (int i = 0; i < features.length; ++i) { + if (features[i] == null) { + continue; + } + int h = features[i].getFeatureAsInt(); + float w1 = _w1.get(h); + float w2 = _w2.get(h); + double xi = features[i].getValue(); + double xx = xi * xi; + score += w1 * xi; + score += w2 * xx; + norm += xx; + for (int j = i + 1; j < features.length; ++j) { + int k = features[j].getFeatureAsInt(); + int hk = HashFunction.hash(h, k, true); + float w3 = _w3.get(hk); + double xj = features[j].getValue(); + score += xi * xj * w3; + } + } + return new PredictionResult(score).squaredNorm(norm); + } + + protected void updateKernel(final float label, final float loss, + @Nonnull final PredictionResult margin, @Nonnull final FeatureValue[] features) { + float eta = _algo.eta(loss, margin); + float coeff = eta * label; + expandKernel(features, coeff); + } + + private void expandKernel(@Nonnull final FeatureValue[] supportVector, final float alpha) { + final float pkc = _pkc; + // W0 += α c^2 + this._w0 += alpha * pkc * pkc; + + for (int i = 0; i < supportVector.length; ++i) { + final FeatureValue si = supportVector[i]; + final int h = si.getFeatureAsInt(); + float Zih = si.getValueAsFloat(); + + float alphaZih = alpha * Zih; + final float alphaZih2 = alphaZih * 2.f; + + // W1[h] += 2 c α Zi[h] + _w1.put(h, _w1.get(h) + pkc * alphaZih2); + // W2[h] += α Zi[h]^2 + _w2.put(h, _w2.get(h) + alphaZih * Zih); + + for (int j = i + 1; j < supportVector.length; ++j) { + FeatureValue sj = supportVector[j]; + int k = sj.getFeatureAsInt(); + int hk = HashFunction.hash(h, k, true); + float Zjk = sj.getValueAsFloat(); + + // W3 += 2 α Zi[h] Zi[k] + _w3.put(hk, _w3.get(hk) + alphaZih2 * Zjk); + } + } + } + + @Override + public void close() throws HiveException { + final IntWritable h = new IntWritable(0); // row[0] + final FloatWritable w0 = new FloatWritable(_w0); // row[1] + final FloatWritable w1 = new FloatWritable(); // row[2] + final FloatWritable w2 = new FloatWritable(); // row[3] + final IntWritable hk = new IntWritable(0); // row[4] + final FloatWritable w3 = new FloatWritable(); // row[5] + final Object[] row = new Object[] {h, w0, null, null, null, null}; + forward(row); // 0(f), w0 + row[1] = null; + + row[2] = w1; + row[3] = w2; + final Int2FloatOpenHashTable w2map = _w2; + final IMapIterator w1itor = _w1.entries(); + while (w1itor.next() != -1) { + int k = w1itor.getKey(); + Preconditions.checkArgument(k > 0, HiveException.class); + h.set(k); + w1.set(w1itor.getValue()); + w2.set(w2map.get(k)); + forward(row); // h(f), w1, w2 + } + this._w1 = null; + this._w2 = null; + + row[0] = null; + row[2] = null; + row[3] = null; + row[4] = hk; + row[5] = w3; + final IMapIterator w3itor = _w3.entries(); + while (w3itor.next() != -1) { + int k = w3itor.getKey(); + Preconditions.checkArgument(k > 0, HiveException.class); + hk.set(k); + w3.set(w3itor.getValue()); + forward(row); // hk(f), w3 + } + this._w3 = null; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/ensemble/bagging/VotedAvgUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ensemble/bagging/VotedAvgUDAF.java b/core/src/main/java/hivemall/ensemble/bagging/VotedAvgUDAF.java index 8b40142..a9bcab1 100644 --- a/core/src/main/java/hivemall/ensemble/bagging/VotedAvgUDAF.java +++ b/core/src/main/java/hivemall/ensemble/bagging/VotedAvgUDAF.java @@ -18,6 +18,8 @@ */ package hivemall.ensemble.bagging; +import javax.annotation.Nullable; + import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDAF; import org.apache.hadoop.hive.ql.exec.UDAFEvaluator; @@ -53,11 +55,15 @@ public final class VotedAvgUDAF extends UDAF { this.partial = null; } - public boolean iterate(double w) { + public boolean iterate(@Nullable DoubleWritable o) { + if (o == null) { + return true; + } if (partial == null) { this.partial = new PartialResult(); partial.init(); } + double w = o.get(); if (w > 0) { partial.positiveSum += w; partial.positiveCnt++; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/ensemble/bagging/WeightVotedAvgUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ensemble/bagging/WeightVotedAvgUDAF.java b/core/src/main/java/hivemall/ensemble/bagging/WeightVotedAvgUDAF.java index a7d63be..4e7ea1b 100644 --- a/core/src/main/java/hivemall/ensemble/bagging/WeightVotedAvgUDAF.java +++ b/core/src/main/java/hivemall/ensemble/bagging/WeightVotedAvgUDAF.java @@ -18,6 +18,8 @@ */ package hivemall.ensemble.bagging; +import javax.annotation.Nullable; + import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDAF; import org.apache.hadoop.hive.ql.exec.UDAFEvaluator; @@ -54,11 +56,15 @@ public final class WeightVotedAvgUDAF extends UDAF { this.partial = null; } - public boolean iterate(double w) { + public boolean iterate(@Nullable DoubleWritable o) { + if (o == null) { + return true; + } if (partial == null) { this.partial = new PartialResult(); partial.init(); } + double w = o.get(); if (w > 0) { partial.positiveSum += w; partial.positiveCnt++; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java new file mode 100644 index 0000000..6aebd64 --- /dev/null +++ b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.ftvec.pairing; + +import hivemall.UDTFWithOptions; +import hivemall.model.FeatureValue; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.hashing.HashFunction; +import hivemall.utils.lang.Preconditions; + +import java.util.ArrayList; +import java.util.List; + +import javax.annotation.Nonnull; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Writable; + +@Description(name = "feature_pairs", + value = "_FUNC_(feature_vector in array<string>, [, const string options])" + + " - Returns a relation <string i, string j, double xi, double xj>") +public final class FeaturePairsUDTF extends UDTFWithOptions { + + private Type _type; + private RowProcessor _proc; + + public FeaturePairsUDTF() {} + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("kpa", false, + "Generate feature pairs for Kernel-Expansion Passive Aggressive [default:true]"); + opts.addOption("ffm", false, + "Generate feature pairs for Field-aware Factorization Machines [default:false]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + if (argOIs.length == 2) { + String args = HiveUtils.getConstString(argOIs[1]); + cl = parseOptions(args); + + Preconditions.checkArgument(cl.getOptions().length == 1, UDFArgumentException.class, + "Only one option can be specified: " + cl.getArgList()); + + if (cl.hasOption("kpa")) { + this._type = Type.kpa; + } else if (cl.hasOption("ffm")) { + this._type = Type.ffm; + } else { + throw new UDFArgumentException("Unsupported option: " + cl.getArgList().get(0)); + } + } else { + throw new UDFArgumentException("MUST provide -kpa or -ffm in the option"); + } + + return cl; + } + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length != 1 && argOIs.length != 2) { + throw new UDFArgumentException("_FUNC_ takes 1 or 2 arguments"); + } + processOptions(argOIs); + + ListObjectInspector fvOI = HiveUtils.asListOI(argOIs[0]); + HiveUtils.validateFeatureOI(fvOI.getListElementObjectInspector()); + + final List<String> fieldNames = new ArrayList<String>(4); + final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(4); + switch (_type) { + case kpa: { + this._proc = new KPAProcessor(fvOI); + fieldNames.add("h"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("hk"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("xh"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("xk"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + break; + } + case ffm: { + throw new UDFArgumentException("-ffm is not supported yet"); + //break; + } + default: + throw new UDFArgumentException("Illegal condition: " + _type); + } + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public void process(Object[] args) throws HiveException { + Object arg0 = args[0]; + if (arg0 == null) { + return; + } + _proc.process(arg0); + } + + public enum Type { + kpa, ffm; + } + + abstract class RowProcessor { + + @Nonnull + protected final ListObjectInspector fvOI; + + RowProcessor(@Nonnull ListObjectInspector fvOI) { + this.fvOI = fvOI; + } + + void process(@Nonnull Object arg) throws HiveException { + final int size = fvOI.getListLength(arg); + if (size == 0) { + return; + } + + final List<FeatureValue> features = new ArrayList<FeatureValue>(size); + for (int i = 0; i < size; i++) { + Object f = fvOI.getListElement(arg, i); + if (f == null) { + continue; + } + FeatureValue fv = FeatureValue.parse(f, true); + features.add(fv); + } + + process(features); + } + + abstract void process(@Nonnull List<FeatureValue> features) throws HiveException; + + } + + final class KPAProcessor extends RowProcessor { + + @Nonnull + private final IntWritable f0, f1; + @Nonnull + private final DoubleWritable f2, f3; + @Nonnull + private final Writable[] forward; + + KPAProcessor(@Nonnull ListObjectInspector fvOI) { + super(fvOI); + this.f0 = new IntWritable(); + this.f1 = new IntWritable(); + this.f2 = new DoubleWritable(); + this.f3 = new DoubleWritable(); + this.forward = new Writable[] {f0, null, null, null}; + } + + @Override + void process(@Nonnull List<FeatureValue> features) throws HiveException { + forward[0] = f0; + f0.set(0); + forward[1] = null; + forward[2] = null; + forward[3] = null; + forward(forward); // forward h(f0) + + forward[2] = f2; + for (int i = 0, len = features.size(); i < len; i++) { + FeatureValue xi = features.get(i); + int h = xi.getFeatureAsInt(); + double xh = xi.getValue(); + forward[0] = f0; + f0.set(h); + forward[1] = null; + f2.set(xh); + forward[3] = null; + forward(forward); // forward h(f0), xh(f2) + + forward[0] = null; + forward[1] = f1; + forward[3] = f3; + for (int j = i + 1; j < len; j++) { + FeatureValue xj = features.get(j); + int k = xj.getFeatureAsInt(); + int hk = HashFunction.hash(h, k, true); + double xk = xj.getValue(); + f1.set(hk); + f3.set(xk); + forward(forward);// forward hk(f1), xh(f2), xk(f3) + } + } + } + } + + + @Override + public void close() throws HiveException { + // clean up to help GC + this._proc = null; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/model/FeatureValue.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/FeatureValue.java b/core/src/main/java/hivemall/model/FeatureValue.java index 8f2b728..7ff3383 100644 --- a/core/src/main/java/hivemall/model/FeatureValue.java +++ b/core/src/main/java/hivemall/model/FeatureValue.java @@ -18,6 +18,9 @@ */ package hivemall.model; +import hivemall.utils.hashing.MurmurHash3; +import hivemall.utils.lang.Preconditions; + import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -44,6 +47,12 @@ public final class FeatureValue { public <T> T getFeature() { return (T) feature; } + + public int getFeatureAsInt() { + Preconditions.checkNotNull(feature); + Preconditions.checkArgument(feature instanceof Integer); + return ((Integer) feature).intValue(); + } public double getValue() { return value; @@ -63,30 +72,42 @@ public final class FeatureValue { @Nullable public static FeatureValue parse(final Object o) throws IllegalArgumentException { + return parse(o, false); + } + + @Nullable + public static FeatureValue parse(final Object o, final boolean mhash) + throws IllegalArgumentException { if (o == null) { return null; } String s = o.toString(); - return parse(s); + return parse(s, mhash); } @Nullable public static FeatureValue parse(@Nonnull final String s) throws IllegalArgumentException { + return parse(s, false); + } + + @Nullable + public static FeatureValue parse(@Nonnull final String s, final boolean mhash) + throws IllegalArgumentException { assert (s != null); final int pos = s.indexOf(':'); if (pos == 0) { throw new IllegalArgumentException("Invalid feature value representation: " + s); } - final Text feature; + final Object feature; final double weight; if (pos > 0) { String s1 = s.substring(0, pos); String s2 = s.substring(pos + 1); - feature = new Text(s1); + feature = mhash ? Integer.valueOf(MurmurHash3.murmurhash3(s1)) : new Text(s1); weight = Double.parseDouble(s2); } else { - feature = new Text(s); + feature = mhash ? Integer.valueOf(MurmurHash3.murmurhash3(s)) : new Text(s); weight = 1.d; } return new FeatureValue(feature, weight); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/utils/collections/FloatArrayList.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/FloatArrayList.java b/core/src/main/java/hivemall/utils/collections/FloatArrayList.java new file mode 100644 index 0000000..cfdf504 --- /dev/null +++ b/core/src/main/java/hivemall/utils/collections/FloatArrayList.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections; + +import java.io.Serializable; + +public final class FloatArrayList implements Serializable { + private static final long serialVersionUID = 8764828070342317585L; + + public static final int DEFAULT_CAPACITY = 12; + + /** array entity */ + private float[] data; + private int used; + + public FloatArrayList() { + this(DEFAULT_CAPACITY); + } + + public FloatArrayList(int size) { + this.data = new float[size]; + this.used = 0; + } + + public FloatArrayList(float[] initValues) { + this.data = initValues; + this.used = initValues.length; + } + + public void add(float value) { + if (used >= data.length) { + expand(used + 1); + } + data[used++] = value; + } + + public void add(float[] values) { + final int needs = used + values.length; + if (needs >= data.length) { + expand(needs); + } + System.arraycopy(values, 0, data, used, values.length); + this.used = needs; + } + + /** + * dynamic expansion. + */ + private void expand(int max) { + while (data.length < max) { + final int len = data.length; + float[] newArray = new float[len * 2]; + System.arraycopy(data, 0, newArray, 0, len); + this.data = newArray; + } + } + + public float remove() { + return data[--used]; + } + + public float remove(int index) { + final float ret; + if (index > used) { + throw new IndexOutOfBoundsException(); + } else if (index == used) { + ret = data[--used]; + } else { // index < used + // removed value + ret = data[index]; + final float[] newarray = new float[--used]; + // prefix + System.arraycopy(data, 0, newarray, 0, index - 1); + // appendix + System.arraycopy(data, index + 1, newarray, index, used - index); + // set fields. + this.data = newarray; + } + return ret; + } + + public void set(int index, float value) { + if (index > used) { + throw new IllegalArgumentException("Index MUST be less than \"size()\"."); + } else if (index == used) { + ++used; + } + data[index] = value; + } + + public float get(int index) { + if (index >= used) + throw new IndexOutOfBoundsException(); + return data[index]; + } + + public float fastGet(int index) { + return data[index]; + } + + public int size() { + return used; + } + + public boolean isEmpty() { + return used == 0; + } + + public void clear() { + this.used = 0; + } + + public float[] toArray() { + final float[] newArray = new float[used]; + System.arraycopy(data, 0, newArray, 0, used); + return newArray; + } + + public float[] array() { + return data; + } + + @Override + public String toString() { + final StringBuilder buf = new StringBuilder(); + buf.append('['); + for (int i = 0; i < used; i++) { + if (i != 0) { + buf.append(", "); + } + buf.append(data[i]); + } + buf.append(']'); + return buf.toString(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/utils/collections/Int2FloatOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/Int2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/Int2FloatOpenHashTable.java index 49f27c8..a06cdb0 100644 --- a/core/src/main/java/hivemall/utils/collections/Int2FloatOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/Int2FloatOpenHashTable.java @@ -73,6 +73,9 @@ public class Int2FloatOpenHashTable implements Externalizable { this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); } + /** + * Only for {@link Externalizable} + */ public Int2FloatOpenHashTable() {// required for serialization this._loadFactor = DEFAULT_LOAD_FACTOR; this._growFactor = DEFAULT_GROW_FACTOR; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 8188b7a..5423c9d 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -308,6 +308,18 @@ public final class HiveUtils { } } + public static boolean isStringTypeInfo(@Nonnull TypeInfo typeInfo) { + if (typeInfo.getCategory() != ObjectInspector.Category.PRIMITIVE) { + return false; + } + switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) { + case STRING: + return true; + default: + return false; + } + } + public static boolean isConstString(@Nonnull final ObjectInspector oi) { return ObjectInspectorUtils.isConstantObjectInspector(oi) && isStringOI(oi); } @@ -321,6 +333,20 @@ public final class HiveUtils { return (ListTypeInfo) typeInfo; } + public static float getFloat(@Nullable Object o, @Nonnull PrimitiveObjectInspector oi) { + if (o == null) { + return 0.f; + } + return PrimitiveObjectInspectorUtils.getFloat(o, oi); + } + + public static double getDouble(@Nullable Object o, @Nonnull PrimitiveObjectInspector oi) { + if (o == null) { + return 0.d; + } + return PrimitiveObjectInspectorUtils.getDouble(o, oi); + } + @SuppressWarnings("unchecked") @Nullable public static <T extends Writable> T getConstValue(@Nonnull final ObjectInspector oi) @@ -776,6 +802,7 @@ public final class HiveUtils { return oi; } + @Nonnull public static PrimitiveObjectInspector asDoubleCompatibleOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) { @@ -802,6 +829,30 @@ public final class HiveUtils { } @Nonnull + public static PrimitiveObjectInspector asNumberOI(@Nonnull final ObjectInspector argOI) + throws UDFArgumentTypeException { + if (argOI.getCategory() != Category.PRIMITIVE) { + throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted but " + + argOI.getTypeName() + " is passed."); + } + final PrimitiveObjectInspector oi = (PrimitiveObjectInspector) argOI; + switch (oi.getPrimitiveCategory()) { + case BYTE: + case SHORT: + case INT: + case LONG: + case FLOAT: + case DOUBLE: + break; + default: + throw new UDFArgumentTypeException(0, + "Only numeric or string type arguments are accepted but " + argOI.getTypeName() + + " is passed."); + } + return oi; + } + + @Nonnull public static ListObjectInspector asListOI(@Nonnull final ObjectInspector oi) throws UDFArgumentException { Category category = oi.getCategory(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/utils/hashing/HashFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hashing/HashFunction.java b/core/src/main/java/hivemall/utils/hashing/HashFunction.java index eb44915..bd1b841 100644 --- a/core/src/main/java/hivemall/utils/hashing/HashFunction.java +++ b/core/src/main/java/hivemall/utils/hashing/HashFunction.java @@ -18,6 +18,8 @@ */ package hivemall.utils.hashing; +import hivemall.utils.math.MathUtils; + public abstract class HashFunction { public int hash(Object data) { @@ -25,6 +27,18 @@ public abstract class HashFunction { return hash(s); } + public static int hash(final int first, final int second, final boolean positive) { + final int h = first * 157 + second; + if (positive) { + int r = MathUtils.moduloPowerOfTwo(h, MurmurHash3.DEFAULT_NUM_FEATURES); + if (r < 0) { + r += MurmurHash3.DEFAULT_NUM_FEATURES; + } + return r; + } + return h; + } + public abstract int hash(String data); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/utils/lang/Preconditions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/Preconditions.java b/core/src/main/java/hivemall/utils/lang/Preconditions.java index af63127..eabbc0a 100644 --- a/core/src/main/java/hivemall/utils/lang/Preconditions.java +++ b/core/src/main/java/hivemall/utils/lang/Preconditions.java @@ -98,19 +98,21 @@ public final class Preconditions { } public static <E extends Throwable> void checkArgument(boolean expression, - @Nonnull String errorMessage, @Nonnull Class<E> clazz) throws E { + @Nonnull Class<E> clazz, @Nullable Object errorMessage) throws E { if (!expression) { + Constructor<E> constructor; + try { + constructor = clazz.getConstructor(String.class); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalStateException( + "Failed to get a constructor of " + clazz.getName(), e); + } final E throwable; try { - Constructor<E> constructor = clazz.getConstructor(String.class); throwable = constructor.newInstance(errorMessage); - } catch (NoSuchMethodException | SecurityException e1) { - throw new IllegalStateException("Failed to get a Constructor(String): " - + clazz.getName(), e1); - } catch (InstantiationException | IllegalAccessException | IllegalArgumentException - | InvocationTargetException e2) { + } catch (ReflectiveOperationException | IllegalArgumentException e) { throw new IllegalStateException( - "Failed to instantiate a class: " + clazz.getName(), e2); + "Failed to instantiate a class: " + clazz.getName(), e); } throw throwable; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/main/java/hivemall/utils/math/MathUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index ae9f029..252ccf6 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -89,8 +89,7 @@ public final class MathUtils { } /** - * <a href="https://en.wikipedia.org/wiki/Logit">Logit</a> is the inverse of - * {@link #sigmoid(double)} function. + * <a href="https://en.wikipedia.org/wiki/Logit">Logit</a> is the inverse of {@link #sigmoid(double)} function. */ public static double logit(final double p) { return Math.log(p / (1.d - p)); @@ -101,14 +100,11 @@ public final class MathUtils { } /** - * Returns the inverse erf. This code is based on erfInv() in - * org.apache.commons.math3.special.Erf. + * Returns the inverse erf. This code is based on erfInv() in org.apache.commons.math3.special.Erf. * <p> - * This implementation is described in the paper: <a - * href="http://people.maths.ox.ac.uk/gilesm/files/gems_erfinv.pdf">Approximating the erfinv - * function</a> by Mike Giles, Oxford-Man Institute of Quantitative Finance, which was published - * in GPU Computing Gems, volume 2, 2010. The source code is available <a - * href="http://gpucomputing.net/?q=node/1828">here</a>. + * This implementation is described in the paper: <a href="http://people.maths.ox.ac.uk/gilesm/files/gems_erfinv.pdf">Approximating the erfinv + * function</a> by Mike Giles, Oxford-Man Institute of Quantitative Finance, which was published in GPU Computing Gems, volume 2, 2010. The source + * code is available <a href="http://gpucomputing.net/?q=node/1828">here</a>. * </p> * * @param x the value @@ -227,8 +223,8 @@ public final class MathUtils { return v < 0 ? -1 : 1; } - public static float sign(final float v) { - return v < 0.f ? -1.f : 1.f; + public static int sign(final float v) { + return v < 0.f ? -1 : 1; } public static double log(final double n, final int base) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/test/java/hivemall/anomaly/ChangeFinder1DTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/anomaly/ChangeFinder1DTest.java b/core/src/test/java/hivemall/anomaly/ChangeFinder1DTest.java index d7bac75..1470ffb 100644 --- a/core/src/test/java/hivemall/anomaly/ChangeFinder1DTest.java +++ b/core/src/test/java/hivemall/anomaly/ChangeFinder1DTest.java @@ -46,7 +46,7 @@ public class ChangeFinder1DTest { ChangeFinder1D cf = new ChangeFinder1D(params, oi); double[] outScores = new double[2]; - BufferedReader reader = readFile("cf1d.csv"); + BufferedReader reader = readFile("cf1d.csv.gz"); println("x outlier change"); String line; int numOutliers = 0, numChangepoints = 0; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/test/java/hivemall/anomaly/ChangeFinder2DTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/anomaly/ChangeFinder2DTest.java b/core/src/test/java/hivemall/anomaly/ChangeFinder2DTest.java index 240906c..43a0921 100644 --- a/core/src/test/java/hivemall/anomaly/ChangeFinder2DTest.java +++ b/core/src/test/java/hivemall/anomaly/ChangeFinder2DTest.java @@ -60,7 +60,7 @@ public class ChangeFinder2DTest { double[] outScores = new double[2]; List<Double> x = new ArrayList<Double>(1); - BufferedReader reader = readFile("cf1d.csv"); + BufferedReader reader = readFile("cf1d.csv.gz"); println("x outlier change"); String line; int i = 1, numOutliers = 0, numChangepoints = 0; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/test/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTFTest.java b/core/src/test/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTFTest.java new file mode 100644 index 0000000..a5c98e1 --- /dev/null +++ b/core/src/test/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTFTest.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.classifier; + +import hivemall.model.FeatureValue; +import hivemall.utils.math.MathUtils; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.text.ParseException; +import java.util.ArrayList; +import java.util.StringTokenizer; +import java.util.zip.GZIPInputStream; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Assert; +import org.junit.Test; + +public class KernelExpansionPassiveAggressiveUDTFTest { + + @Test + public void testNews20() throws IOException, ParseException, HiveException { + KernelExpansionPassiveAggressiveUDTF udtf = new KernelExpansionPassiveAggressiveUDTF(); + ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + udtf.initialize(new ObjectInspector[] {stringListOI, intOI}); + + BufferedReader news20 = readFile("news20-small.binary.gz"); + ArrayList<String> words = new ArrayList<String>(); + String line = news20.readLine(); + while (line != null) { + StringTokenizer tokens = new StringTokenizer(line, " "); + int label = Integer.parseInt(tokens.nextToken()); + while (tokens.hasMoreTokens()) { + words.add(tokens.nextToken()); + } + Assert.assertFalse(words.isEmpty()); + udtf.process(new Object[] {words, label}); + + words.clear(); + line = news20.readLine(); + } + + Assert.assertTrue(Math.abs(udtf.getLoss()) < 0.1f); + + news20.close(); + } + + public void test_a9a() throws IOException, ParseException, HiveException { + KernelExpansionPassiveAggressiveUDTF udtf = new KernelExpansionPassiveAggressiveUDTF(); + ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-c 0.01"); + + udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params}); + + final ArrayList<String> words = new ArrayList<String>(); + BufferedReader trainData = readFile("a9a.gz"); + String line = trainData.readLine(); + while (line != null) { + StringTokenizer tokens = new StringTokenizer(line, " "); + String labelStr = tokens.nextToken(); + final int label; + if ("+1".equals(labelStr)) { + label = 1; + } else if ("-1".equals(labelStr)) { + label = -1; + } else { + throw new IllegalStateException("Illegal label: " + labelStr); + } + while (tokens.hasMoreTokens()) { + words.add(tokens.nextToken()); + } + Assert.assertFalse(words.isEmpty()); + udtf.process(new Object[] {words, label}); + + words.clear(); + line = trainData.readLine(); + } + trainData.close(); + + int numTests = 0; + int numCorrect = 0; + + BufferedReader testData = readFile("a9a.t.gz"); + line = testData.readLine(); + while (line != null) { + StringTokenizer tokens = new StringTokenizer(line, " "); + String labelStr = tokens.nextToken(); + final int actual; + if ("+1".equals(labelStr)) { + actual = 1; + } else if ("-1".equals(labelStr)) { + actual = -1; + } else { + throw new IllegalStateException("Illegal label: " + labelStr); + } + while (tokens.hasMoreTokens()) { + words.add(tokens.nextToken()); + } + Assert.assertFalse(words.isEmpty()); + + FeatureValue[] features = udtf.parseFeatures(words); + float score = udtf.predict(features); + int predicted = MathUtils.sign(score); + + if (predicted == actual) { + ++numCorrect; + } + ++numTests; + + words.clear(); + line = testData.readLine(); + } + testData.close(); + + float accuracy = numCorrect / (float) numTests; + Assert.assertTrue(accuracy > 0.82f); + } + + @Nonnull + private static BufferedReader readFile(@Nonnull String fileName) throws IOException { + InputStream is = KernelExpansionPassiveAggressiveUDTFTest.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java index eacfa8d..6d053de 100644 --- a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java +++ b/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java @@ -20,10 +20,14 @@ package hivemall.fm; import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.List; import java.util.StringTokenizer; +import java.util.zip.GZIPInputStream; + +import javax.annotation.Nonnull; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -56,8 +60,7 @@ public class FactorizationMachineUDTFTest { double loss = 0.d; double cumul = 0.d; for (int trainingIteration = 1; trainingIteration <= ITERATIONS; ++trainingIteration) { - BufferedReader data = new BufferedReader(new InputStreamReader( - getClass().getResourceAsStream("5107786.txt"))); + BufferedReader data = readFile("5107786.txt.gz"); loss = udtf._cvState.getCumulativeLoss(); int trExamples = 0; String line = data.readLine(); @@ -78,8 +81,17 @@ public class FactorizationMachineUDTFTest { println(trainingIteration + " " + loss + " " + cumul / (trainingIteration * trExamples)); data.close(); } + Assert.assertTrue("Loss was greater than 0.1: " + loss, loss <= 0.1); + } + @Nonnull + private static BufferedReader readFile(@Nonnull String fileName) throws IOException { + InputStream is = FactorizationMachineUDTFTest.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); } private static void println(String line) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java index f4d4f80..792ede1 100644 --- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java +++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java @@ -20,8 +20,12 @@ package hivemall.fm; import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; import java.io.InputStreamReader; import java.util.ArrayList; +import java.util.zip.GZIPInputStream; + +import javax.annotation.Nonnull; import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.ql.metadata.HiveException; @@ -87,8 +91,7 @@ public class FieldAwareFactorizationMachineUDTFTest { double loss = 0.d; double cumul = 0.d; for (int trainingIteration = 1; trainingIteration <= ITERATIONS; ++trainingIteration) { - BufferedReader data = new BufferedReader(new InputStreamReader( - FieldAwareFactorizationMachineUDTFTest.class.getResourceAsStream("bigdata.tr.txt"))); + BufferedReader data = readFile("bigdata.tr.txt.gz"); loss = udtf._cvState.getCumulativeLoss(); int lines = 0; for (int lineNumber = 0; lineNumber < MAX_LINES; ++lineNumber, ++lines) { @@ -131,6 +134,15 @@ public class FieldAwareFactorizationMachineUDTFTest { Assert.assertTrue("Last loss was greater than expected: " + loss, loss < lossThreshold); } + @Nonnull + private static BufferedReader readFile(@Nonnull String fileName) throws IOException { + InputStream is = FieldAwareFactorizationMachineUDTFTest.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); + } + private static String[] toStringArray(ArrayList<StringFeature> x) { final int size = x.size(); final String[] ret = new String[size]; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/test/java/hivemall/mf/BPRMatrixFactorizationUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/mf/BPRMatrixFactorizationUDTFTest.java b/core/src/test/java/hivemall/mf/BPRMatrixFactorizationUDTFTest.java index e28d318..41f1f97 100644 --- a/core/src/test/java/hivemall/mf/BPRMatrixFactorizationUDTFTest.java +++ b/core/src/test/java/hivemall/mf/BPRMatrixFactorizationUDTFTest.java @@ -18,11 +18,13 @@ */ package hivemall.mf; -import hivemall.utils.io.IOUtils; import hivemall.utils.lang.StringUtils; import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.zip.GZIPInputStream; import javax.annotation.Nonnull; @@ -63,7 +65,7 @@ public class BPRMatrixFactorizationUDTFTest { final IntWritable negItem = new IntWritable(); final Object[] args = new Object[] {user, posItem, negItem}; - BufferedReader train = IOUtils.bufferedReader(getClass().getResourceAsStream("ml1k.train")); + BufferedReader train = readFile("ml1k.train.gz"); String line; while ((line = train.readLine()) != null) { parseLine(line, user, posItem, negItem); @@ -98,7 +100,7 @@ public class BPRMatrixFactorizationUDTFTest { final IntWritable negItem = new IntWritable(); final Object[] args = new Object[] {user, posItem, negItem}; - BufferedReader train = IOUtils.bufferedReader(getClass().getResourceAsStream("ml1k.train")); + BufferedReader train = readFile("ml1k.train.gz"); String line; while ((line = train.readLine()) != null) { parseLine(line, user, posItem, negItem); @@ -109,6 +111,14 @@ public class BPRMatrixFactorizationUDTFTest { Assert.assertTrue("finishedIter: " + finishedIter, finishedIter < iterations); } + @Nonnull + private static BufferedReader readFile(@Nonnull String fileName) throws IOException { + InputStream is = BPRMatrixFactorizationUDTFTest.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); + } private static void parseLine(@Nonnull String line, @Nonnull IntWritable user, @Nonnull IntWritable posItem, @Nonnull IntWritable negItem) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java index 14635be..8a8a68d 100644 --- a/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java +++ b/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java @@ -80,4 +80,17 @@ public class Int2FloatOpenHashMapTest { Assert.assertEquals(-1, itor.next()); } + @Test + public void testIterator2() { + Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(100); + map.put(33, 3.16f); + + Int2FloatOpenHashTable.IMapIterator itor = map.entries(); + Assert.assertTrue(itor.hasNext()); + Assert.assertNotEquals(-1, itor.next()); + Assert.assertEquals(33, itor.getKey()); + Assert.assertEquals(3.16f, itor.getValue(), 0.d); + Assert.assertEquals(-1, itor.next()); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/test/java/hivemall/utils/io/Base91OutputStreamTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/io/Base91OutputStreamTest.java b/core/src/test/java/hivemall/utils/io/Base91OutputStreamTest.java index 99a132f..c7f41f9 100644 --- a/core/src/test/java/hivemall/utils/io/Base91OutputStreamTest.java +++ b/core/src/test/java/hivemall/utils/io/Base91OutputStreamTest.java @@ -50,7 +50,7 @@ public class Base91OutputStreamTest { @Test public void testLargeEncodeOutDecodeIn() throws IOException { - InputStream in = ArrayModelTest.class.getResourceAsStream("bigdata.tr.txt"); + InputStream in = ArrayModelTest.class.getResourceAsStream("bigdata.tr.txt.gz"); byte[] expected = IOUtils.toByteArray(in); FastByteArrayOutputStream bos = new FastByteArrayOutputStream(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/391e7f1c/core/src/test/java/hivemall/utils/lang/PreconditionsTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/lang/PreconditionsTest.java b/core/src/test/java/hivemall/utils/lang/PreconditionsTest.java index 4cbbd3f..cf27a75 100644 --- a/core/src/test/java/hivemall/utils/lang/PreconditionsTest.java +++ b/core/src/test/java/hivemall/utils/lang/PreconditionsTest.java @@ -52,7 +52,7 @@ public class PreconditionsTest { public void testCheckArgumentBooleanClassOfE2() { final String msg = "safdfvzfd"; try { - Preconditions.checkArgument(false, msg, HiveException.class); + Preconditions.checkArgument(false, HiveException.class, msg); } catch (HiveException e) { if (e.getMessage().equals(msg)) { return;