Close #66: [HIVEMALL-91] Implement Online LDA
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9b2ddcc7 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9b2ddcc7 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9b2ddcc7 Branch: refs/heads/master Commit: 9b2ddcc76b0950124373a30c1dbc56acff664ebf Parents: bba252a Author: Takuya Kitazawa <k.tak...@gmail.com> Authored: Thu Apr 20 16:33:20 2017 +0900 Committer: myui <yuin...@gmail.com> Committed: Thu Apr 20 16:33:20 2017 +0900 ---------------------------------------------------------------------- .../main/java/hivemall/model/FeatureValue.java | 4 + .../hivemall/topicmodel/LDAPredictUDAF.java | 476 ++++++++++++++++ .../main/java/hivemall/topicmodel/LDAUDTF.java | 567 +++++++++++++++++++ .../hivemall/topicmodel/OnlineLDAModel.java | 554 ++++++++++++++++++ .../java/hivemall/utils/lang/ArrayUtils.java | 20 + .../java/hivemall/utils/math/MathUtils.java | 43 ++ .../hivemall/topicmodel/LDAPredictUDAFTest.java | 228 ++++++++ .../java/hivemall/topicmodel/LDAUDTFTest.java | 104 ++++ .../hivemall/topicmodel/OnlineLDAModelTest.java | 252 +++++++++ docs/gitbook/SUMMARY.md | 8 +- docs/gitbook/clustering/lda.md | 195 +++++++ resources/ddl/define-all-as-permanent.hive | 10 + resources/ddl/define-all.hive | 10 + resources/ddl/define-all.spark | 10 + resources/ddl/define-udfs.td.hql | 2 + 15 files changed, 2481 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/model/FeatureValue.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/FeatureValue.java b/core/src/main/java/hivemall/model/FeatureValue.java index 39fadaf..11aa8f0 100644 --- a/core/src/main/java/hivemall/model/FeatureValue.java +++ 
b/core/src/main/java/hivemall/model/FeatureValue.java @@ -54,6 +54,10 @@ public final class FeatureValue { return ((Integer) feature).intValue(); } + public String getFeatureAsString() { + return feature.toString(); + } + public double getValue() { return value; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java new file mode 100644 index 0000000..811af2e --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.topicmodel; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.CommandLineUtils; +import hivemall.utils.lang.Primitives; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +import javax.annotation.Nonnull; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.FloatWritable; +import 
org.apache.hadoop.io.IntWritable; + +@Description(name = "lda_predict", + value = "_FUNC_(string word, float value, int label, float lambda[, const string options])" + + " - Returns a list which consists of <int label, float prob>") +public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { + + @Override + public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException { + if (typeInfo.length != 4 && typeInfo.length != 5) { + throw new UDFArgumentLengthException( + "Expected argument length is 4 or 5 but given argument length was " + + typeInfo.length); + } + + if (!HiveUtils.isStringTypeInfo(typeInfo[0])) { + throw new UDFArgumentTypeException(0, + "String type is expected for the first argument word: " + typeInfo[0].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[1])) { + throw new UDFArgumentTypeException(1, + "Number type is expected for the second argument value: " + + typeInfo[1].getTypeName()); + } + if (!HiveUtils.isIntegerTypeInfo(typeInfo[2])) { + throw new UDFArgumentTypeException(2, + "Integer type is expected for the third argument label: " + + typeInfo[2].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) { + throw new UDFArgumentTypeException(3, + "Number type is expected for the forth argument lambda: " + + typeInfo[3].getTypeName()); + } + + if (typeInfo.length == 5) { + if (!HiveUtils.isStringTypeInfo(typeInfo[4])) { + throw new UDFArgumentTypeException(4, + "String type is expected for the fifth argument lambda: " + + typeInfo[4].getTypeName()); + } + } + + return new Evaluator(); + } + + public static class Evaluator extends GenericUDAFEvaluator { + + // input OI + private PrimitiveObjectInspector wordOI; + private PrimitiveObjectInspector valueOI; + private PrimitiveObjectInspector labelOI; + private PrimitiveObjectInspector lambdaOI; + + // Hyperparameters + private int topic; + private float alpha; + private double delta; + + // merge OI + private StructObjectInspector internalMergeOI; 
+ private StructField wcListField; + private StructField lambdaMapField; + private StructField topicOptionField; + private StructField alphaOptionField; + private StructField deltaOptionField; + private PrimitiveObjectInspector wcListElemOI; + private StandardListObjectInspector wcListOI; + private StandardMapObjectInspector lambdaMapOI; + private PrimitiveObjectInspector lambdaMapKeyOI; + private StandardListObjectInspector lambdaMapValueOI; + private PrimitiveObjectInspector lambdaMapValueElemOI; + + public Evaluator() {} + + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", "topic", true, "The number of topics [required]"); + opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]"); + opts.addOption("delta", true, + "Check convergence in the expectation step [default: 1E-5]"); + return opts; + } + + @Nonnull + protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException { + String[] args = optionValue.split("\\s+"); + Options opts = getOptions(); + opts.addOption("help", false, "Show function help"); + CommandLine cl = CommandLineUtils.parseOptions(args, opts); + + if (cl.hasOption("help")) { + Description funcDesc = getClass().getAnnotation(Description.class); + final String cmdLineSyntax; + if (funcDesc == null) { + cmdLineSyntax = getClass().getSimpleName(); + } else { + String funcName = funcDesc.name(); + cmdLineSyntax = funcName == null ? 
getClass().getSimpleName() + : funcDesc.value().replace("_FUNC_", funcDesc.name()); + } + StringWriter sw = new StringWriter(); + sw.write('\n'); + PrintWriter pw = new PrintWriter(sw); + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts, + HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true); + pw.flush(); + String helpMsg = sw.toString(); + throw new UDFArgumentException(helpMsg); + } + + return cl; + } + + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + + if (argOIs.length != 5) { + throw new UDFArgumentException("At least 1 option `-topic` MUST be specified"); + } + + String rawArgs = HiveUtils.getConstString(argOIs[4]); + cl = parseOptions(rawArgs); + + this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 0); + if (topic < 1) { + throw new UDFArgumentException( + "A positive integer MUST be set to an option `-topic`: " + topic); + } + + this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic); + this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-5d); + + return cl; + } + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { + assert (parameters.length == 4 || parameters.length == 5); + super.init(mode, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + processOptions(parameters); + this.wordOI = HiveUtils.asStringOI(parameters[0]); + this.valueOI = HiveUtils.asDoubleCompatibleOI(parameters[1]); + this.labelOI = HiveUtils.asIntegerOI(parameters[2]); + this.lambdaOI = HiveUtils.asDoubleCompatibleOI(parameters[3]); + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) parameters[0]; + this.internalMergeOI = soi; + this.wcListField = soi.getStructFieldRef("wcList"); + this.lambdaMapField = 
soi.getStructFieldRef("lambdaMap"); + this.topicOptionField = soi.getStructFieldRef("topic"); + this.alphaOptionField = soi.getStructFieldRef("alpha"); + this.deltaOptionField = soi.getStructFieldRef("delta"); + this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + this.wcListOI = ObjectInspectorFactory.getStandardListObjectInspector(wcListElemOI); + this.lambdaMapKeyOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + this.lambdaMapValueElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + this.lambdaMapValueOI = ObjectInspectorFactory.getStandardListObjectInspector(lambdaMapValueElemOI); + this.lambdaMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + lambdaMapKeyOI, lambdaMapValueOI); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = internalMergeOI(); + } else { + final ArrayList<String> fieldNames = new ArrayList<String>(); + final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldNames.add("label"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("probability"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + outputOI = ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardStructObjectInspector( + fieldNames, fieldOIs)); + } + return outputOI; + } + + private static StructObjectInspector internalMergeOI() { + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("wcList"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector)); + + fieldNames.add("lambdaMap"); + fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, 
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector))); + + fieldNames.add("topic"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + + fieldNames.add("alpha"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + fieldNames.add("delta"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @SuppressWarnings("deprecation") + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + AggregationBuffer myAggr = new OnlineLDAPredictAggregationBuffer(); + reset(myAggr); + return myAggr; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; + myAggr.reset(); + myAggr.setOptions(topic, alpha, delta); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; + + if (parameters[0] == null || parameters[1] == null || parameters[2] == null + || parameters[3] == null) { + return; + } + + String word = PrimitiveObjectInspectorUtils.getString(parameters[0], wordOI); + float value = HiveUtils.getFloat(parameters[1], valueOI); + int label = PrimitiveObjectInspectorUtils.getInt(parameters[2], labelOI); + float lambda = HiveUtils.getFloat(parameters[3], lambdaOI); + + myAggr.iterate(word, value, label, lambda); + } + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; + if (myAggr.wcList.size() == 0) { + return null; + } + + Object[] 
partialResult = new Object[5]; + partialResult[0] = myAggr.wcList; + partialResult[1] = myAggr.lambdaMap; + partialResult[2] = new IntWritable(myAggr.topic); + partialResult[3] = new FloatWritable(myAggr.alpha); + partialResult[4] = new DoubleWritable(myAggr.delta); + + return partialResult; + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + Object wcListObj = internalMergeOI.getStructFieldData(partial, wcListField); + + List<?> wcListRaw = wcListOI.getList(HiveUtils.castLazyBinaryObject(wcListObj)); + + // fix list elements to Java String objects + int wcListSize = wcListRaw.size(); + List<String> wcList = new ArrayList<String>(); + for (int i = 0; i < wcListSize; i++) { + wcList.add(PrimitiveObjectInspectorUtils.getString(wcListRaw.get(i), wcListElemOI)); + } + + Object lambdaMapObj = internalMergeOI.getStructFieldData(partial, lambdaMapField); + Map<?, ?> lambdaMapRaw = lambdaMapOI.getMap(HiveUtils.castLazyBinaryObject(lambdaMapObj)); + + Map<String, List<Float>> lambdaMap = new HashMap<String, List<Float>>(); + for (Map.Entry<?, ?> e : lambdaMapRaw.entrySet()) { + // fix map keys to Java String objects + String word = PrimitiveObjectInspectorUtils.getString(e.getKey(), lambdaMapKeyOI); + + Object lambdaMapValueObj = e.getValue(); + List<?> lambdaMapValueRaw = lambdaMapValueOI.getList(HiveUtils.castLazyBinaryObject(lambdaMapValueObj)); + + // fix map values to lists of Java Float objects + int lambdaMapValueSize = lambdaMapValueRaw.size(); + List<Float> lambda_word = new ArrayList<Float>(); + for (int i = 0; i < lambdaMapValueSize; i++) { + lambda_word.add(HiveUtils.getFloat(lambdaMapValueRaw.get(i), + lambdaMapValueElemOI)); + } + + lambdaMap.put(word, lambda_word); + } + + // restore options from partial result + Object topicObj = internalMergeOI.getStructFieldData(partial, topicOptionField); + this.topic = 
PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj); + + Object alphaObj = internalMergeOI.getStructFieldData(partial, alphaOptionField); + this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj); + + Object deltaObj = internalMergeOI.getStructFieldData(partial, deltaOptionField); + this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(deltaObj); + + OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; + myAggr.setOptions(topic, alpha, delta); + myAggr.merge(wcList, lambdaMap); + } + + @Override + public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; + float[] topicDistr = myAggr.get(); + + SortedMap<Float, Integer> sortedDistr = new TreeMap<Float, Integer>( + Collections.reverseOrder()); + for (int i = 0; i < topicDistr.length; i++) { + sortedDistr.put(topicDistr[i], i); + } + + List<Object[]> result = new ArrayList<Object[]>(); + for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) { + Object[] struct = new Object[2]; + struct[0] = new IntWritable(e.getValue()); // label + struct[1] = new FloatWritable(e.getKey()); // probability + result.add(struct); + } + return result; + } + + } + + public static class OnlineLDAPredictAggregationBuffer extends + GenericUDAFEvaluator.AbstractAggregationBuffer { + + private List<String> wcList; + private Map<String, List<Float>> lambdaMap; + + private int topic; + private float alpha; + private double delta; + + OnlineLDAPredictAggregationBuffer() { + super(); + } + + void setOptions(int topic, float alpha, double delta) { + this.topic = topic; + this.alpha = alpha; + this.delta = delta; + } + + void reset() { + this.wcList = new ArrayList<String>(); + this.lambdaMap = new HashMap<String, List<Float>>(); + } + + void iterate(String word, float value, int label, float lambda) { + 
wcList.add(word + ":" + value); + + // for an unforeseen word, initialize its lambdas w/ -1s + if (!lambdaMap.containsKey(word)) { + List<Float> lambdaEmpty_word = new ArrayList<Float>( + Collections.nCopies(topic, -1.f)); + lambdaMap.put(word, lambdaEmpty_word); + } + + // set the given lambda value + List<Float> lambda_word = lambdaMap.get(word); + lambda_word.set(label, lambda); + lambdaMap.put(word, lambda_word); + } + + void merge(List<String> o_wcList, Map<String, List<Float>> o_lambdaMap) { + wcList.addAll(o_wcList); + + for (Map.Entry<String, List<Float>> e : o_lambdaMap.entrySet()) { + String o_word = e.getKey(); + List<Float> o_lambda_word = e.getValue(); + + if (!lambdaMap.containsKey(o_word)) { // for an unforeseen word + lambdaMap.put(o_word, o_lambda_word); + } else { // for a partially observed word + List<Float> lambda_word = lambdaMap.get(o_word); + for (int k = 0; k < topic; k++) { + if (o_lambda_word.get(k) != -1.f) { // not default value + lambda_word.set(k, o_lambda_word.get(k)); // set the partial lambda value + } + } + lambdaMap.put(o_word, lambda_word); + } + } + } + + float[] get() { + OnlineLDAModel model = new OnlineLDAModel(topic, alpha, delta); + + for (String word : lambdaMap.keySet()) { + List<Float> lambda_word = lambdaMap.get(word); + for (int k = 0; k < topic; k++) { + model.setLambda(word, k, lambda_word.get(k)); + } + } + + String[] wcArray = wcList.toArray(new String[wcList.size()]); + return model.getTopicDistribution(wcArray); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/topicmodel/LDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java new file mode 100644 index 0000000..91ee7a2 --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java @@ -0,0 +1,567 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.topicmodel; + +import hivemall.UDTFWithOptions; +import hivemall.annotations.VisibleForTesting; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NioStatefullSegment; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.SizeOf; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import 
org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.Counters; +import org.apache.hadoop.mapred.Reporter; + +@Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])" + + " - Returns a relation consists of <int topic, string word, float score>") +public class LDAUDTF extends UDTFWithOptions { + private static final Log logger = LogFactory.getLog(LDAUDTF.class); + + // Options + protected int topic; + protected float alpha; + protected float eta; + protected long numDocs; + protected double tau0; + protected double kappa; + protected int iterations; + protected double delta; + protected double eps; + protected int miniBatchSize; + + // if `num_docs` option is not given, this flag will be true + // in that case, UDTF automatically sets `count` value to the _D parameter in an online LDA model + protected boolean isAutoD; + + // number of proceeded training samples + protected long count; + + protected String[][] miniBatch; + protected int miniBatchCount; + + protected transient OnlineLDAModel model; + + protected ListObjectInspector wordCountsOI; + + // for iterations + protected NioStatefullSegment fileIO; + protected ByteBuffer inputBuf; + + public LDAUDTF() { + this.topic = 10; + this.alpha = 1.f / topic; + this.eta = 1.f / topic; + this.numDocs = -1L; + this.tau0 = 64.d; + this.kappa = 0.7; + this.iterations = 10; + this.delta = 1E-3d; + this.eps = 1E-1d; + this.miniBatchSize = 128; // if 1, truly online setting + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", "topic", true, "The number of topics [default: 10]"); + opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]"); + 
opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]"); + opts.addOption("d", "num_docs", true, "The total number of documents [default: auto]"); + opts.addOption("tau", "tau0", true, + "The parameter which downweights early iterations [default: 64.0]"); + opts.addOption("kappa", true, "Exponential decay rate (i.e., learning rate) [default: 0.7]"); + opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]"); + opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]"); + opts.addOption("eps", "epsilon", true, + "Check convergence based on the difference of perplexity [default: 1E-1]"); + opts.addOption("s", "mini_batch_size", true, + "Repeat model updating per mini-batch [default: 128]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + + if (argOIs.length >= 2) { + String rawArgs = HiveUtils.getConstString(argOIs[1]); + cl = parseOptions(rawArgs); + this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 10); + this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic); + this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topic); + this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L); + this.tau0 = Primitives.parseDouble(cl.getOptionValue("tau0"), 64.d); + if (tau0 <= 0.d) { + throw new UDFArgumentException("'-tau0' must be positive: " + tau0); + } + this.kappa = Primitives.parseDouble(cl.getOptionValue("kappa"), 0.7d); + if (kappa <= 0.5 || kappa > 1.d) { + throw new UDFArgumentException("'-kappa' must be in (0.5, 1.0]: " + kappa); + } + this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10); + if (iterations < 1) { + throw new UDFArgumentException( + "'-iterations' must be greater than or equals to 1: " + iterations); + } + this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-3d); + 
this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d); + this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128); + } + + return cl; + } + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length < 1) { + throw new UDFArgumentException( + "_FUNC_ takes 1 arguments: array<string> words [, const string options]"); + } + + this.wordCountsOI = HiveUtils.asListOI(argOIs[0]); + HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector()); + + processOptions(argOIs); + + this.model = null; + this.count = 0L; + this.isAutoD = (numDocs < 0L); + this.miniBatch = new String[miniBatchSize][]; + this.miniBatchCount = 0; + + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldNames.add("topic"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("word"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + fieldNames.add("score"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + protected void initModel() { + this.model = new OnlineLDAModel(topic, alpha, eta, numDocs, tau0, kappa, delta); + } + + @Override + public void process(Object[] args) throws HiveException { + if (model == null) { + initModel(); + } + + int length = wordCountsOI.getListLength(args[0]); + String[] wordCounts = new String[length]; + int j = 0; + for (int i = 0; i < length; i++) { + Object o = wordCountsOI.getListElement(args[0], i); + if (o == null) { + throw new HiveException("Given feature vector contains invalid elements"); + } + String s = o.toString(); + wordCounts[j] = s; + j++; + } + + count++; + if (isAutoD) { + model.setNumTotalDocs(count); + } + + recordTrainSampleToTempFile(wordCounts); + + 
miniBatch[miniBatchCount] = wordCounts; + miniBatchCount++; + + if (miniBatchCount == miniBatchSize) { + model.train(miniBatch); + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + } + } + + protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts) + throws HiveException { + if (iterations == 1) { + return; + } + + ByteBuffer buf = inputBuf; + NioStatefullSegment dst = fileIO; + + if (buf == null) { + final File file; + try { + file = File.createTempFile("hivemall_lda", ".sgmt"); + file.deleteOnExit(); + if (!file.canWrite()) { + throw new UDFArgumentException("Cannot write a temporary file: " + + file.getAbsolutePath()); + } + logger.info("Record training samples to a file: " + file.getAbsolutePath()); + } catch (IOException ioe) { + throw new UDFArgumentException(ioe); + } catch (Throwable e) { + throw new UDFArgumentException(e); + } + this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB + this.fileIO = dst = new NioStatefullSegment(file, false); + } + + int wcLength = 0; + for (String wc : wordCounts) { + if (wc == null) { + continue; + } + wcLength += wc.getBytes().length; + } + // recordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ... 
+ int recordBytes = (Integer.SIZE * 2 + Integer.SIZE * wcLength) / 8 + wcLength; + int remain = buf.remaining(); + if (remain < recordBytes) { + writeBuffer(buf, dst); + } + + buf.putInt(recordBytes); + buf.putInt(wordCounts.length); + for (String wc : wordCounts) { + if (wc == null) { + continue; + } + buf.putInt(wc.length()); + buf.put(wc.getBytes()); + } + } + + private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst) + throws HiveException { + srcBuf.flip(); + try { + dst.write(srcBuf); + } catch (IOException e) { + throw new HiveException("Exception causes while writing a buffer to file", e); + } + srcBuf.clear(); + } + + @Override + public void close() throws HiveException { + if (count == 0) { + this.model = null; + return; + } + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + } + if (iterations > 1) { + runIterativeTraining(iterations); + } + forwardModel(); + this.model = null; + } + + protected final void runIterativeTraining(@Nonnegative final int iterations) + throws HiveException { + final ByteBuffer buf = this.inputBuf; + final NioStatefullSegment dst = this.fileIO; + assert (buf != null); + assert (dst != null); + final long numTrainingExamples = count; + + final Reporter reporter = getReporter(); + final Counters.Counter iterCounter = (reporter == null) ? 
null : reporter.getCounter( + "hivemall.lda.OnlineLDA$Counter", "iteration"); + + try { + if (dst.getPosition() == 0L) {// run iterations w/o temporary file + if (buf.position() == 0) { + return; // no training example + } + buf.flip(); + + int iter = 2; + float perplexityPrev = Float.MAX_VALUE; + float perplexity; + int numTrain; + for (; iter <= iterations; iter++) { + perplexity = 0.f; + numTrain = 0; + + reportProgress(reporter); + setCounterValue(iterCounter, iter); + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + + while (buf.remaining() > 0) { + int recordBytes = buf.getInt(); + assert (recordBytes > 0) : recordBytes; + int wcLength = buf.getInt(); + final String[] wordCounts = new String[wcLength]; + for (int j = 0; j < wcLength; j++) { + int len = buf.getInt(); + byte[] bytes = new byte[len]; + buf.get(bytes); + wordCounts[j] = new String(bytes); + } + + miniBatch[miniBatchCount] = wordCounts; + miniBatchCount++; + + if (miniBatchCount == miniBatchSize) { + model.train(miniBatch); + perplexity += model.computePerplexity(); + numTrain++; + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + } + } + buf.rewind(); + + // update for remaining samples + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + perplexity += model.computePerplexity(); + numTrain++; + } + + logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain); + perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches + if (Math.abs(perplexityPrev - perplexity) < eps) { + break; + } + perplexityPrev = perplexity; + } + logger.info("Performed " + + Math.min(iter, iterations) + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on memory (thus " + + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) + + " training updates in total) "); + } else {// read training examples in the temporary file and invoke 
train for each example + + // write training examples in buffer to a temporary file + if (buf.remaining() > 0) { + writeBuffer(buf, dst); + } + try { + dst.flush(); + } catch (IOException e) { + throw new HiveException("Failed to flush a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (logger.isInfoEnabled()) { + File tmpFile = dst.getFile(); + logger.info("Wrote " + numTrainingExamples + + " records to a temporary file for iterative training: " + + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + + ")"); + } + + // run iterations + int iter = 2; + float perplexityPrev = Float.MAX_VALUE; + float perplexity; + int numTrain; + for (; iter <= iterations; iter++) { + perplexity = 0.f; + numTrain = 0; + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + + setCounterValue(iterCounter, iter); + + buf.clear(); + dst.resetPosition(); + while (true) { + reportProgress(reporter); + // TODO prefetch + // writes training examples to a buffer in the temporary file + final int bytesRead; + try { + bytesRead = dst.read(buf); + } catch (IOException e) { + throw new HiveException("Failed to read a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (bytesRead == 0) { // reached file EOF + break; + } + assert (bytesRead > 0) : bytesRead; + + // reads training examples from a buffer + buf.flip(); + int remain = buf.remaining(); + if (remain < SizeOf.INT) { + throw new HiveException("Illegal file format was detected"); + } + while (remain >= SizeOf.INT) { + int pos = buf.position(); + int recordBytes = buf.getInt(); + remain -= SizeOf.INT; + if (remain < recordBytes) { + buf.position(pos); + break; + } + + int wcLength = buf.getInt(); + final String[] wordCounts = new String[wcLength]; + for (int j = 0; j < wcLength; j++) { + int len = buf.getInt(); + byte[] bytes = new byte[len]; + buf.get(bytes); + wordCounts[j] = new String(bytes); + } + + miniBatch[miniBatchCount] = wordCounts; + miniBatchCount++; + + if (miniBatchCount == 
miniBatchSize) { + model.train(miniBatch); + perplexity += model.computePerplexity(); + numTrain++; + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + } + + remain -= recordBytes; + } + buf.compact(); + } + + // update for remaining samples + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + perplexity += model.computePerplexity(); + numTrain++; + } + + logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain); + perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches + if (Math.abs(perplexityPrev - perplexity) < eps) { + break; + } + perplexityPrev = perplexity; + } + logger.info("Performed " + + Math.min(iter, iterations) + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on a secondary storage (thus " + + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) + + " training updates in total)"); + } + } finally { + // delete the temporary file and release resources + try { + dst.close(true); + } catch (IOException e) { + throw new HiveException("Failed to close a file: " + + dst.getFile().getAbsolutePath(), e); + } + this.inputBuf = null; + this.fileIO = null; + } + } + + protected void forwardModel() throws HiveException { + final IntWritable topicIdx = new IntWritable(); + final Text word = new Text(); + final FloatWritable score = new FloatWritable(); + + final Object[] forwardObjs = new Object[3]; + forwardObjs[0] = topicIdx; + forwardObjs[1] = word; + forwardObjs[2] = score; + + for (int k = 0; k < topic; k++) { + topicIdx.set(k); + + final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + score.set(e.getKey()); + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + word.set(words.get(i)); + forward(forwardObjs); + } + } + } + + logger.info("Forwarded topic words 
each of " + topic + " topics"); + } + + /* + * For testing: + */ + + @VisibleForTesting + double getLambda(String label, int k) { + return model.getLambda(label, k); + } + + @VisibleForTesting + SortedMap<Float, List<String>> getTopicWords(int k) { + return model.getTopicWords(k); + } + + @VisibleForTesting + SortedMap<Float, List<String>> getTopicWords(int k, int topN) { + return model.getTopicWords(k, topN); + } + + @VisibleForTesting + float[] getTopicDistribution(@Nonnull String[] doc) { + return model.getTopicDistribution(doc); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java new file mode 100644 index 0000000..3e7ad10 --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java @@ -0,0 +1,554 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.topicmodel; + +import hivemall.annotations.VisibleForTesting; +import hivemall.model.FeatureValue; +import hivemall.utils.lang.ArrayUtils; +import hivemall.utils.math.MathUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +import org.apache.commons.math3.distribution.GammaDistribution; +import org.apache.commons.math3.special.Gamma; + +public final class OnlineLDAModel { + + // number of topics + private final int _K; + + // prior on weight vectors "theta ~ Dir(alpha_)" + private final float _alpha; + + // prior on topics "beta" + private final float _eta; + + // total number of documents + // in the truly online setting, this can be an estimate of the maximum number of documents that could ever seen + private long _D = -1L; + + // defined by (tau0 + updateCount)^(-kappa_) + // controls how much old lambda is forgotten + private double _rhot; + + // positive value which downweights early iterations + @Nonnegative + private final double _tau0; + + // exponential decay rate (i.e., learning rate) which must be in (0.5, 1] to guarantee convergence + private final double _kappa; + + // how many times EM steps are launched; later EM steps do not drastically forget old lambda + private long _updateCount = 0L; + + // random number generator + @Nonnull + private final GammaDistribution _gd; + private static final double SHAPE = 100.d; + private static final double SCALE = 1.d / SHAPE; + + // parameters + @Nonnull + private List<Map<String, float[]>> _phi; + private float[][] _gamma; + @Nonnull + private final Map<String, float[]> _lambda; + + // check convergence in the expectation (E) step + private final double _delta; + + @Nonnull + private final List<Map<String, Float>> _miniBatchMap; + private int _miniBatchSize; + + // for 
computing perplexity + private float _docRatio = 1.f; + private long _wordCount = 0L; + + public OnlineLDAModel(int K, float alpha, double delta) { // for E step only instantiation + this(K, alpha, 1 / 20.f, -1L, 1020, 0.7, delta); + } + + public OnlineLDAModel(int K, float alpha, float eta, long D, double tau0, double kappa, + double delta) { + if (tau0 < 0.d) { + throw new IllegalArgumentException("tau0 MUST be positive: " + tau0); + } + if (kappa <= 0.5 || 1.d < kappa) { + throw new IllegalArgumentException("kappa MUST be in (0.5, 1.0]: " + kappa); + } + + this._K = K; + this._alpha = alpha; + this._eta = eta; + this._D = D; + this._tau0 = tau0; + this._kappa = kappa; + this._delta = delta; + + // initialize a random number generator + this._gd = new GammaDistribution(SHAPE, SCALE); + _gd.reseedRandomGenerator(1001); + + // initialize the parameters + this._lambda = new HashMap<String, float[]>(100); + + this._miniBatchMap = new ArrayList<Map<String, Float>>(); + } + + /** + * In a truly online setting, total number of documents corresponds to the number of documents + * that have ever seen. In that case, users need to manually set the current max number of documents + * via this method. + * Note that, since the same set of documents could be repeatedly passed to `train()`, + * simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient. 
+ */ + public void setNumTotalDocs(@Nonnegative long D) { + this._D = D; + } + + public void train(@Nonnull final String[][] miniBatch) { + if (_D <= 0L) { + throw new RuntimeException("Total number of documents MUST be set via `setNumTotalDocs()`"); + } + + preprocessMiniBatch(miniBatch); + + initParams(true); + + // Expectation + eStep(); + + this._rhot = Math.pow(_tau0 + _updateCount, -_kappa); + + // Maximization + mStep(); + + _updateCount++; + } + + private void preprocessMiniBatch(@Nonnull final String[][] miniBatch) { + initMiniBatchMap(miniBatch, _miniBatchMap); + + this._miniBatchSize = _miniBatchMap.size(); + + // accumulate the number of words for each documents + this._wordCount = 0L; + for (int d = 0; d < _miniBatchSize; d++) { + for (float n : _miniBatchMap.get(d).values()) { + this._wordCount += n; + } + } + + this._docRatio = (float)((double) _D / _miniBatchSize); + } + + private static void initMiniBatchMap(@Nonnull final String[][] miniBatch, + @Nonnull final List<Map<String, Float>> map) { + map.clear(); + + final FeatureValue probe = new FeatureValue(); + + // parse document + for (final String[] e : miniBatch) { + if (e == null) { + continue; + } + + final Map<String, Float> docMap = new HashMap<String, Float>(); + + // parse features + for (String fv : e) { + if (fv == null) { + continue; + } + FeatureValue.parseFeatureAsString(fv, probe); + String label = probe.getFeatureAsString(); + float value = probe.getValueAsFloat(); + docMap.put(label, value); + } + + map.add(docMap); + } + } + + private void initParams(boolean gammaWithRandom) { + _phi = new ArrayList<Map<String, float[]>>(); + _gamma = new float[_miniBatchSize][]; + + for (int d = 0; d < _miniBatchSize; d++) { + if (gammaWithRandom) { + _gamma[d] = ArrayUtils.newRandomFloatArray(_K, _gd); + } else { + _gamma[d] = ArrayUtils.newInstance(_K, 1.f); + } + + final Map<String, float[]> phi_d = new HashMap<String, float[]>(); + _phi.add(phi_d); + for (String label : 
_miniBatchMap.get(d).keySet()) { + phi_d.put(label, new float[_K]); + if (!_lambda.containsKey(label)) { // lambda for newly observed word + _lambda.put(label, ArrayUtils.newRandomFloatArray(_K, _gd)); + } + } + } + } + + private void eStep() { + // since lambda is invariant in the expectation step, + // `digamma`s of lambda values for Elogbeta are pre-computed + final float[] lambdaSum = new float[_K]; + final Map<String, float[]> digamma_lambda = new HashMap<String, float[]>(); + for (Map.Entry<String, float[]> e : _lambda.entrySet()) { + String label = e.getKey(); + float[] lambda_label = e.getValue(); + + // for digamma(lambdaSum) + MathUtils.add(lambdaSum, lambda_label, _K); + + float[] digamma_lambda_label = new float[_K]; + digamma_lambda.put(label, MathUtils.digamma(lambda_label)); + } + final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum); + + float[] gamma_d, gammaPrev_d; + Map<String, float[]> eLogBeta_d; + + // for each of mini-batch documents, update gamma until convergence + for (int d = 0; d < _miniBatchSize; d++) { + gamma_d = _gamma[d]; + eLogBeta_d = computeElogBetaPerDoc(d, digamma_lambda, digamma_lambdaSum); + + do { + // (deep) copy the last gamma values + gammaPrev_d = gamma_d.clone(); + + updatePhiPerDoc(d, eLogBeta_d); + updateGammaPerDoc(d); + } while (!checkGammaDiff(gammaPrev_d, gamma_d)); + } + } + + @Nonnull + private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative final int d, + @Nonnull Map<String, float[]> digamma_lambda, @Nonnull float[] digamma_lambdaSum) { + // Dirichlet expectation (2d) for lambda + final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>(); + final Map<String, Float> doc = _miniBatchMap.get(d); + + for (String label : doc.keySet()) { + float[] eLogBeta_label = eLogBeta_d.get(label); + if (eLogBeta_label == null) { + eLogBeta_label = new float[_K]; + eLogBeta_d.put(label, eLogBeta_label); + } + final float[] digamma_lambda_label = digamma_lambda.get(label); + for (int k = 0; k < _K; 
k++) {
                eLogBeta_label[k] = digamma_lambda_label[k] - digamma_lambdaSum[k];
            }
        }

        return eLogBeta_d;
    }

    /**
     * Updates phi for document {@code d}: phi[label][k] is proportional to
     * exp(Elogbeta[label][k] + Elogtheta_d[k]) and is normalized over topics.
     *
     * @param d index of the document within the current mini-batch
     * @param eLogBeta_d per-label E[log beta] vectors for this document
     */
    private void updatePhiPerDoc(@Nonnegative final int d, @Nonnull Map<String, float[]> eLogBeta_d) {
        // Dirichlet expectation (2d) for gamma
        final float[] eLogTheta_d = new float[_K];
        final float[] gamma_d = _gamma[d];
        final float digamma_gammaSum_d = (float) Gamma.digamma(MathUtils.sum(gamma_d));
        for (int k = 0; k < _K; k++) {
            eLogTheta_d[k] = (float) Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
        }

        // updating phi w/ normalization
        final Map<String, float[]> phi_d = _phi.get(d);
        final Map<String, Float> doc = _miniBatchMap.get(d);
        for (String label : doc.keySet()) {
            final float[] phi_label = phi_d.get(label);
            final float[] eLogBeta_label = eLogBeta_d.get(label);

            float normalizer = 0.f;
            for (int k = 0; k < _K; k++) {
                // 1E-20f guards against an all-zero row, so the normalizer is never 0
                float phiVal = (float) Math.exp(eLogBeta_label[k] + eLogTheta_d[k]) + 1E-20f;
                phi_label[k] = phiVal;
                normalizer += phiVal;
            }

            // normalize
            for (int k = 0; k < _K; k++) {
                phi_label[k] /= normalizer;
            }
        }
    }

    /**
     * Updates gamma for document {@code d}:
     * gamma_d[k] = alpha + sum over words( phi[label][k] * count(label) ).
     *
     * @param d index of the document within the current mini-batch
     */
    private void updateGammaPerDoc(@Nonnegative final int d) {
        final Map<String, Float> doc = _miniBatchMap.get(d);
        final Map<String, float[]> phi_d = _phi.get(d);

        final float[] gamma_d = _gamma[d];
        for (int k = 0; k < _K; k++) {
            gamma_d[k] = _alpha;
        }
        for (Map.Entry<String, Float> e : doc.entrySet()) {
            final float[] phi_label = phi_d.get(e.getKey());
            final float val = e.getValue();
            for (int k = 0; k < _K; k++) {
                gamma_d[k] += phi_label[k] * val;
            }
        }
    }

    /**
     * Convergence test for gamma: true when the mean absolute per-topic change
     * between two successive gamma vectors falls below {@code _delta}.
     *
     * @param gammaPrev gamma values before the last update (length {@code _K})
     * @param gammaNext gamma values after the last update (length {@code _K})
     * @return whether gamma is considered converged
     */
    private boolean checkGammaDiff(@Nonnull final float[] gammaPrev,
            @Nonnull final float[] gammaNext) {
        double diff = 0.d;
        for (int k = 0; k < _K; k++) {
            diff += Math.abs(gammaPrev[k] - gammaNext[k]);
        }
        return (diff / _K) < _delta;
    }

    /**
     * Maximization (M) step: blends the current lambda towards lambdaTilde, the
     * sufficient statistics collected from the mini-batch, with step size {@code _rhot}.
     */
    private void mStep() {
        // calculate lambdaTilde for vocabularies in the current mini-batch
        final Map<String, float[]> lambdaTilde = new HashMap<String,
float[]>();
        for (int d = 0; d < _miniBatchSize; d++) {
            final Map<String, float[]> phi_d = _phi.get(d);
            for (String label : _miniBatchMap.get(d).keySet()) {
                float[] lambdaTilde_label = lambdaTilde.get(label);
                if (lambdaTilde_label == null) {
                    // rows start from the prior eta
                    lambdaTilde_label = ArrayUtils.newInstance(_K, _eta);
                    lambdaTilde.put(label, lambdaTilde_label);
                }

                final float[] phi_label = phi_d.get(label);
                for (int k = 0; k < _K; k++) {
                    // _docRatio (= D / miniBatchSize) scales the mini-batch statistics
                    // up to the whole corpus
                    lambdaTilde_label[k] += _docRatio * phi_label[k];
                }
            }
        }

        // update lambda for all vocabularies
        for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
            String label = e.getKey();
            final float[] lambda_label = e.getValue();

            float[] lambdaTilde_label = lambdaTilde.get(label);
            if (lambdaTilde_label == null) {
                // word not present in this mini-batch: decay its lambda towards the prior
                lambdaTilde_label = ArrayUtils.newInstance(_K, _eta);
            }

            for (int k = 0; k < _K; k++) {
                // online update: (1 - rhot) * old + rhot * new
                lambda_label[k] = (float) ((1.d - _rhot) * lambda_label[k] + _rhot
                        * lambdaTilde_label[k]);
            }
        }
    }

    /**
     * Calculate approximate perplexity for the current mini-batch.
     * Defined as exp(-bound / (docRatio * wordCount)); lower is better.
     */
    public float computePerplexity() {
        float bound = computeApproxBound();
        float perWordBound = bound / (_docRatio * _wordCount);
        return (float) Math.exp(-1.f * perWordBound);
    }

    /**
     * Estimates the variational bound over all documents using only the documents passed as mini-batch.
+ */ + private float computeApproxBound() { + float score = 0.f; + + // prepare + final float[] gammaSum = new float[_miniBatchSize]; + for (int d = 0; d < _miniBatchSize; d++) { + gammaSum[d] = MathUtils.sum(_gamma[d]); + } + final float[] digamma_gammaSum = MathUtils.digamma(gammaSum); + + final float[] lambdaSum = new float[_K]; + for (float[] lambda_label : _lambda.values()) { + MathUtils.add(lambdaSum, lambda_label, _K); + } + final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum); + + final float logGamma_alpha = (float) Gamma.logGamma(_alpha); + final float logGamma_alphaSum = (float) Gamma.logGamma(_K * _alpha); + + for (int d = 0; d < _miniBatchSize; d++) { + final float digamma_gammaSum_d = digamma_gammaSum[d]; + + // E[log p(doc | theta, beta)] + for (Map.Entry<String, Float> e : _miniBatchMap.get(d).entrySet()) { + final float[] lambda_label = _lambda.get(e.getKey()); + + // logsumexp( Elogthetad + Elogbetad ) + final float[] temp = new float[_K]; + float max = Float.MIN_VALUE; + for (int k = 0; k < _K; k++) { + final float eLogTheta_dk = (float) Gamma.digamma(_gamma[d][k]) - digamma_gammaSum_d; + final float eLogBeta_kw = (float) Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k]; + + temp[k] = eLogTheta_dk + eLogBeta_kw; + if (temp[k] > max) { + max = temp[k]; + } + } + float logsumexp = 0.f; + for (int k = 0; k < _K; k++) { + logsumexp += (float) Math.exp(temp[k] - max); + } + logsumexp = max + (float) Math.log(logsumexp); + + // sum( word count * logsumexp(...) 
) + score += e.getValue() * logsumexp; + } + + // E[log p(theta | alpha) - log q(theta | gamma)] + for (int k = 0; k < _K; k++) { + final float gamma_dk = _gamma[d][k]; + + // sum( (alpha - gammad) * Elogthetad ) + score += (_alpha - gamma_dk) + * ((float) Gamma.digamma(gamma_dk) - digamma_gammaSum_d); + + // sum( gammaln(gammad) - gammaln(alpha) ) + score += (float) Gamma.logGamma(gamma_dk) - logGamma_alpha; + } + score += logGamma_alphaSum; // gammaln(sum(alpha)) + score -= Gamma.logGamma(gammaSum[d]); // gammaln(sum(gammad)) + } + + // assuming likelihood for when corpus in the documents is only a subset of the whole corpus + // (i.e., online setting); likelihood should be always roughly on the same scale + score *= _docRatio; + + final float logGamma_eta = (float) Gamma.logGamma(_eta); + final float logGamma_etaSum = (float) Gamma.logGamma(_eta * _lambda.size()); // vocabulary size * eta + + // E[log p(beta | eta) - log q (beta | lambda)] + for (float[] lambda_label : _lambda.values()) { + for (int k = 0; k < _K; k++) { + final float lambda_k = lambda_label[k]; + + // sum( (eta - lambda) * Elogbeta ) + score += (_eta - lambda_k) + * (float) (Gamma.digamma(lambda_k) - digamma_lambdaSum[k]); + + // sum( gammaln(lambda) - gammaln(eta) ) + score += (float) Gamma.logGamma(lambda_k) - logGamma_eta; + } + } + for (int k = 0; k < _K; k++) { + // sum( gammaln(etaSum) - gammaln( lambdaSum_k ) + score += logGamma_etaSum - (float) Gamma.logGamma(lambdaSum[k]); + } + + return score; + } + + @VisibleForTesting + double getLambda(@Nonnull final String label, @Nonnegative final int k) { + final float[] lambda_label = _lambda.get(label); + if (lambda_label == null) { + throw new IllegalArgumentException("Word `" + label + "` is not in the corpus."); + } + if (k >= lambda_label.length) { + throw new IllegalArgumentException("Topic index must be in [0, " + + _lambda.get(label).length + "]"); + } + return lambda_label[k]; + } + + public void setLambda(@Nonnull final String label, 
@Nonnegative final int k, final float lambda_k) { + float[] lambda_label = _lambda.get(label); + if (lambda_label == null) { + lambda_label = ArrayUtils.newRandomFloatArray(_K, _gd); + _lambda.put(label, lambda_label); + } + lambda_label[k] = lambda_k; + } + + @Nonnull + public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) { + return getTopicWords(k, _lambda.keySet().size()); + } + + @Nonnull + public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k, + @Nonnegative int topN) { + float lambdaSum = 0.f; + final SortedMap<Float, List<String>> sortedLambda = new TreeMap<Float, List<String>>( + Collections.reverseOrder()); + + for (Map.Entry<String, float[]> e : _lambda.entrySet()) { + final float lambda_k = e.getValue()[k]; + lambdaSum += lambda_k; + + List<String> labels = sortedLambda.get(lambda_k); + if (labels == null) { + labels = new ArrayList<String>(); + sortedLambda.put(lambda_k, labels); + } + labels.add(e.getKey()); + } + + final SortedMap<Float, List<String>> ret = new TreeMap<Float, List<String>>( + Collections.reverseOrder()); + + topN = Math.min(topN, _lambda.keySet().size()); + int tt = 0; + for (Map.Entry<Float, List<String>> e : sortedLambda.entrySet()) { + ret.put(e.getKey() / lambdaSum, e.getValue()); + + if (++tt == topN) { + break; + } + } + + return ret; + } + + @Nonnull + public float[] getTopicDistribution(@Nonnull final String[] doc) { + preprocessMiniBatch(new String[][] {doc}); + + initParams(false); + + eStep(); + + // normalize topic distribution + final float[] topicDistr = new float[_K]; + final float[] gamma0 = _gamma[0]; + final float gammaSum = MathUtils.sum(gamma0); + for (int k = 0; k < _K; k++) { + topicDistr[k] = gamma0[k] / gammaSum; + } + return topicDistr; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/utils/lang/ArrayUtils.java ---------------------------------------------------------------------- diff --git 
a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java index e8e337d..711aac7 100644 --- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java +++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java @@ -23,9 +23,12 @@ import java.util.Arrays; import java.util.List; import java.util.Random; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import org.apache.commons.math3.distribution.GammaDistribution; + public final class ArrayUtils { /** @@ -715,4 +718,21 @@ public final class ArrayUtils { return cnt; } + @Nonnull + public static float[] newInstance(@Nonnegative int size, float filledValue) { + final float[] a = new float[size]; + Arrays.fill(a, filledValue); + return a; + } + + @Nonnull + public static float[] newRandomFloatArray(@Nonnegative final int size, + @Nonnull final GammaDistribution gd) { + final float[] ret = new float[size]; + for (int i = 0; i < size; i++) { + ret[i] = (float) gd.sample(); + } + return ret; + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/utils/math/MathUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index b71d165..7fdea55 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -38,6 +38,9 @@ import java.util.Random; import javax.annotation.Nonnegative; import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.math3.special.Gamma; public final class MathUtils { @@ -311,4 +314,44 @@ public final class MathUtils { return perm; } + public static float sum(@Nullable final float[] a) { + if (a == null) { + return 0.f; + } + + float sum = 0.f; + for (float v : a) { + sum += v; + } + return sum; + 
} + + public static float sum(@Nullable final float[] a, @Nonnegative final int size) { + if (a == null) { + return 0.f; + } + + float sum = 0.f; + for (int i = 0; i < size; i++) { + sum += a[i]; + } + return sum; + } + + public static void add(@Nonnull final float[] dst, @Nonnull final float[] toAdd, final int size) { + for (int i = 0; i < size; i++) { + dst[i] += toAdd[i]; + } + } + + @Nonnull + public static float[] digamma(@Nonnull final float[] a) { + final int k = a.length; + final float[] ret = new float[k]; + for (int i = 0; i < k; i++) { + ret[i] = (float) Gamma.digamma(a[i]); + } + return ret; + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java new file mode 100644 index 0000000..a23d917 --- /dev/null +++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.topicmodel; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class LDAPredictUDAFTest { + LDAPredictUDAF udaf; + GenericUDAFEvaluator evaluator; + ObjectInspector[] inputOIs; + ObjectInspector[] partialOI; + LDAPredictUDAF.OnlineLDAPredictAggregationBuffer agg; + + String[] words; + int[] labels; + float[] lambdas; + + @Test(expected=UDFArgumentException.class) + public void testWithoutOption() throws Exception { + udaf = new LDAPredictUDAF(); + + inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.STRING), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.INT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.FLOAT)}; + + evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + } + + @Test(expected=UDFArgumentException.class) + public void testWithoutTopicOption() throws Exception { + udaf = 
new LDAPredictUDAF(); + + inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.STRING), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.INT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-alpha 0.1")}; + + evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + } + + @Before + public void setUp() throws Exception { + udaf = new LDAPredictUDAF(); + + inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.STRING), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.INT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2")}; + + evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); + + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("wcList"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector)); + + fieldNames.add("lambdaMap"); + 
// Struct field OI: per-word probability map (word -> list of per-topic floats)
// used in the partial-aggregation struct.
fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
    ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));

// Struct field: topic index carried through partial aggregation.
fieldNames.add("topic");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);

// Struct field: "alpha" — presumably the LDA alpha hyperparameter;
// NOTE(review): confirm semantics against LDAPredictUDAF.
fieldNames.add("alpha");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);

// Struct field: "delta" — presumably a perplexity/convergence value;
// NOTE(review): confirm semantics against LDAPredictUDAF.
fieldNames.add("delta");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);

// Only slot 0 of the partial-result OI array is populated; the remaining
// slots stay null. NOTE(review): assumes the evaluator only reads
// partialOI[0] in PARTIAL2 mode — confirm against LDAPredictUDAF.
partialOI = new ObjectInspector[4];
partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);

agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();

// Model fixture: parallel arrays forming 18 (word, topic-label, lambda) rows
// of a trained 2-topic LDA model. Rows 0-8 belong to topic 0, rows 9-17 to
// topic 1; lambdas are the per-topic word weights.
words = new String[] {"fruits", "vegetables", "healthy", "flu", "apples", "oranges", "like",
    "avocados", "colds", "colds", "avocados", "oranges", "like", "apples", "flu", "healthy",
    "vegetables", "fruits"};
labels = new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1};
lambdas = new float[] {0.3339331f, 0.3324783f, 0.33209667f, 3.2804057E-4f, 3.0303953E-4f,
    2.4860457E-4f, 2.41481E-4f, 2.3554532E-4f, 1.352576E-4f, 0.1660153f, 0.16596903f,
    0.1659654f, 0.1659627f, 0.16593699f, 0.1659259f, 0.0017611005f, 0.0015791848f,
    8.84464E-4f};
}

/**
 * Runs the PARTIAL1 iterate phase once per document and checks that the
 * aggregated topic distributions discriminate between the two documents:
 * doc1 (fruits/healthy/vegetables) must score higher on topic 0 than doc2,
 * and doc2 (apples/avocados/colds/...) higher on topic 1 than doc1.
 */
@Test
public void test() throws Exception {
    final Map<String, Float> doc1 = new HashMap<String, Float>();
    doc1.put("fruits", 1.f);
    doc1.put("healthy", 1.f);
    doc1.put("vegetables", 1.f);

    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 0; i < words.length; i++) {
        String word = words[i];
        // doc1.get(word) is null for words absent from the document;
        // the evaluator is expected to tolerate a null frequency.
        evaluator.iterate(agg, new Object[] {word, doc1.get(word), labels[i], lambdas[i]});
    }
    float[] doc1Distr = agg.get();

    final Map<String, Float> doc2 = new HashMap<String, Float>();
    doc2.put("apples", 1.f);
    doc2.put("avocados", 1.f);
    doc2.put("colds", 1.f);
    doc2.put("flu", 1.f);
    doc2.put("like", 2.f);
    doc2.put("oranges", 1.f);

    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 0; i < words.length; i++) {
        String word = words[i];
        evaluator.iterate(agg, new Object[] {word, doc2.get(word), labels[i], lambdas[i]});
    }
    float[] doc2Distr = agg.get();

    Assert.assertTrue(doc1Distr[0] > doc2Distr[0]);
    Assert.assertTrue(doc1Distr[1] < doc2Distr[1]);
}


/**
 * Splits the 18 model rows into three partial aggregations (bins of 6 rows),
 * then merges the partials in several different orders and checks that the
 * final topic distribution is the same regardless of merge order: the doc2
 * fixture must always lean toward topic 1 (distr[0] < distr[1]).
 */
@Test
public void testMerge() throws Exception {
    final Map<String, Float> doc = new HashMap<String, Float>();
    doc.put("apples", 1.f);
    doc.put("avocados", 1.f);
    doc.put("colds", 1.f);
    doc.put("flu", 1.f);
    doc.put("like", 2.f);
    doc.put("oranges", 1.f);

    Object[] partials = new Object[3];

    // bin #1: model rows 0-5
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 0; i < 6; i++) {
        evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], lambdas[i]});
    }
    partials[0] = evaluator.terminatePartial(agg);

    // bin #2: model rows 6-11
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 6; i < 12; i++) {
        evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], lambdas[i]});
    }
    partials[1] = evaluator.terminatePartial(agg);

    // bin #3: model rows 12-17
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 12; i < 18; i++) {
        evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], lambdas[i]});
    }

    partials[2] = evaluator.terminatePartial(agg);

    // merge in a different order each iteration; result must not depend on order
    final int[][] orders = new int[][] {{0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}};
    for (int i = 0; i < orders.length; i++) {
        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
        evaluator.reset(agg);

        evaluator.merge(agg, partials[orders[i][0]]);
        evaluator.merge(agg, partials[orders[i][1]]);
        evaluator.merge(agg, partials[orders[i][2]]);

        float[] distr = agg.get();
        Assert.assertTrue(distr[0] < distr[1]);
    }
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java new file mode 100644 index 0000000..d1e3f81 --- /dev/null +++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.  See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package hivemall.topicmodel; + +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.Arrays; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +import org.junit.Assert; +import org.junit.Test; + +public class LDAUDTFTest { + private static final boolean DEBUG = false; + + @Test + public void test() throws HiveException { + LDAUDTF udtf = new LDAUDTF(); + + ObjectInspector[] argOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2 -num_docs 2 -s 1")}; + + udtf.initialize(argOIs); + + String[] doc1 = new String[]{"fruits:1", "healthy:1", "vegetables:1"}; + String[] doc2 = new String[]{"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"}; + for (int it = 0; it < 5; it++) { + udtf.process(new Object[]{ Arrays.asList(doc1) }); + udtf.process(new Object[]{ Arrays.asList(doc2) }); + } + + SortedMap<Float, List<String>> topicWords; + + println("Topic 0:"); + println("========"); + topicWords = udtf.getTopicWords(0); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + println("Topic 1:"); + println("========"); + topicWords = udtf.getTopicWords(1); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + 
println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + int k1, k2; + float[] topicDistr = udtf.getTopicDistribution(doc1); + if (topicDistr[0] > topicDistr[1]) { + // topic 0 MUST represent doc#1 + k1 = 0; + k2 = 1; + } else { + k1 = 1; + k2 = 0; + } + + Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic", + udtf.getLambda("vegetables", k1) > udtf.getLambda("flu", k1)); + Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " + + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", + udtf.getLambda("avocados", k2) > udtf.getLambda("healthy", k2)); + } + + private static void println(String msg) { + if (DEBUG) { + System.out.println(msg); + } + } +}