Close #72: [HIVEMALL-86] Updated Hadoop version dependencies from cdh3 to v2.4.0
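This change moves the build off the old hadoop-core 0.20.2-cdh3u6 artifact onto hadoop-common and hadoop-mapreduce-client-core resolved through a shared ${hadoop.version} property (2.4.0 per the subject), replaces the jarjar'd Guava (r09-jarjar) with stock com.google.guava, and centralizes the JUnit version. The commit also brings in the new online LDA topic modeling support: a train_lda UDTF (LDAUDTF), an lda_predict UDAF (LDAPredictUDAF), and the underlying OnlineLDAModel, together with tests, DDL registrations, and gitbook docs.

For orientation, here is a minimal sketch of driving the new OnlineLDAModel directly from Java. It is not part of this commit: the driver class name is hypothetical, the sample words are borrowed from the bundled tests, and it assumes the hivemall-core jar and its commons-math3 dependency are on the classpath.

    import hivemall.topicmodel.OnlineLDAModel;

    public class OnlineLDAModelExample {
        public static void main(String[] args) {
            // K=2 topics, alpha=eta=1/K, D=2 total documents, tau0=64, kappa=0.7, delta=1e-5
            OnlineLDAModel model = new OnlineLDAModel(2, 0.5f, 0.5f, 2, 64.d, 0.7d, 1E-5d);

            // each document is an array of "word:count" features, the same format the UDTF feeds in
            String[][] miniBatch = new String[][] {
                    {"fruits:1", "healthy:1", "vegetables:1"},
                    {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"}};
            model.train(miniBatch); // one online variational Bayes update (E step + M step)

            // per-topic probabilities for an unseen document
            float[] distr = model.getTopicDistribution(new String[] {"fruits:2", "healthy:1"});
            for (int k = 0; k < distr.length; k++) {
                System.out.println("topic " + k + ": " + distr[k]);
            }
        }
    }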
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/cb16a394 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/cb16a394 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/cb16a394 Branch: refs/heads/master Commit: cb16a39440dad5442092d33d6127d71e2ad91d79 Parents: 8aae974 Author: myui <yuin...@gmail.com> Authored: Tue Apr 18 18:46:29 2017 +0900 Committer: myui <yuin...@gmail.com> Committed: Tue Apr 18 18:46:29 2017 +0900 ---------------------------------------------------------------------- bin/spark-shell | 2 +- core/pom.xml | 16 +- .../hivemall/topicmodel/LDAPredictUDAF.java | 476 ++++++++++++++++ .../main/java/hivemall/topicmodel/LDAUDTF.java | 559 +++++++++++++++++++ .../hivemall/topicmodel/OnlineLDAModel.java | 522 +++++++++++++++++ .../hivemall/topicmodel/LDAPredictUDAFTest.java | 228 ++++++++ .../java/hivemall/topicmodel/LDAUDTFTest.java | 197 +++++++ .../hivemall/topicmodel/OnlineLDAModelTest.java | 108 ++++ docs/gitbook/SUMMARY.md | 8 +- docs/gitbook/clustering/lda.md | 170 ++++++ docs/gitbook/getting_started/installation.md | 16 +- mixserv/pom.xml | 16 +- nlp/pom.xml | 16 +- pom.xml | 28 +- resources/ddl/define-all-as-permanent.hive | 10 + resources/ddl/define-all.hive | 10 + resources/ddl/define-all.spark | 10 + resources/ddl/define-udfs.td.hql | 2 + spark/spark-2.0/pom.xml | 2 +- spark/spark-2.1/pom.xml | 2 +- spark/spark-common/pom.xml | 10 +- xgboost/pom.xml | 10 +- 22 files changed, 2376 insertions(+), 42 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb16a394/bin/spark-shell ---------------------------------------------------------------------- diff --git a/bin/spark-shell b/bin/spark-shell index 5dcd5d5..199e001 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -71,7 +71,7 @@ install_app() { # install Spark under the bin/ folder if needed. install_spark() { local SPARK_VERSION=`grep "<spark.version>" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` - local HADOOP_VERSION=`grep "<hadoop.version>" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` + local HADOOP_VERSION=`grep "<hadoop.version>" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}' | cut -d '.' 
-f1-2` local SPARK_DIR="${_DIR}/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}" local APACHE_MIRROR=${APACHE_MIRROR:-'http://d3kbcqa49mib13.cloudfront.net'} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb16a394/core/pom.xml ---------------------------------------------------------------------- diff --git a/core/pom.xml b/core/pom.xml index d7655f4..9368993 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -39,8 +39,14 @@ <!-- provided scope --> <dependency> <groupId>org.apache.hadoop</groupId> - <artifactId>hadoop-core</artifactId> - <version>0.20.2-cdh3u6</version> + <artifactId>hadoop-common</artifactId> + <version>${hadoop.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-mapreduce-client-core</artifactId> + <version>${hadoop.version}</version> <scope>provided</scope> </dependency> <dependency> @@ -92,9 +98,9 @@ <scope>provided</scope> </dependency> <dependency> - <groupId>org.apache.hadoop.thirdparty.guava</groupId> + <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> - <version>r09-jarjar</version> + <version>${guava.version}</version> <scope>provided</scope> </dependency> @@ -141,7 +147,7 @@ <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> - <version>4.12</version> + <version>${junit.version}</version> <scope>test</scope> </dependency> <dependency> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb16a394/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java new file mode 100644 index 0000000..811af2e --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.topicmodel; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.CommandLineUtils; +import hivemall.utils.lang.Primitives; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +import javax.annotation.Nonnull; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; + +@Description(name = "lda_predict", + value = "_FUNC_(string word, float value, int label, float lambda[, const string options])" + + " - Returns a list consisting of <int label, float prob>") +public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { + + @Override + public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException { + if (typeInfo.length != 4 && typeInfo.length != 5) { + throw new UDFArgumentLengthException( + "Expected argument length is 4 or 5 but given argument length was " + + typeInfo.length); + } + + if (!HiveUtils.isStringTypeInfo(typeInfo[0])) { + throw new UDFArgumentTypeException(0, + "String type is expected for the first argument word: " + typeInfo[0].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[1])) { + throw new UDFArgumentTypeException(1, + "Number type is expected for the second argument value: " + + typeInfo[1].getTypeName()); + } + if (!HiveUtils.isIntegerTypeInfo(typeInfo[2])) { + throw new UDFArgumentTypeException(2, + "Integer type is expected for the third argument label: " + + typeInfo[2].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) { + throw new UDFArgumentTypeException(3, + "Number type is expected for the fourth argument lambda: " + + typeInfo[3].getTypeName()); + } + + if (typeInfo.length == 5) { + if (!HiveUtils.isStringTypeInfo(typeInfo[4])) { + throw new UDFArgumentTypeException(4, + "String type is expected for the fifth argument options: " + + typeInfo[4].getTypeName()); + } + } + + return new 
Evaluator(); + } + + public static class Evaluator extends GenericUDAFEvaluator { + + // input OI + private PrimitiveObjectInspector wordOI; + private PrimitiveObjectInspector valueOI; + private PrimitiveObjectInspector labelOI; + private PrimitiveObjectInspector lambdaOI; + + // Hyperparameters + private int topic; + private float alpha; + private double delta; + + // merge OI + private StructObjectInspector internalMergeOI; + private StructField wcListField; + private StructField lambdaMapField; + private StructField topicOptionField; + private StructField alphaOptionField; + private StructField deltaOptionField; + private PrimitiveObjectInspector wcListElemOI; + private StandardListObjectInspector wcListOI; + private StandardMapObjectInspector lambdaMapOI; + private PrimitiveObjectInspector lambdaMapKeyOI; + private StandardListObjectInspector lambdaMapValueOI; + private PrimitiveObjectInspector lambdaMapValueElemOI; + + public Evaluator() {} + + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", "topic", true, "The number of topics [required]"); + opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]"); + opts.addOption("delta", true, + "Check convergence in the expectation step [default: 1E-5]"); + return opts; + } + + @Nonnull + protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException { + String[] args = optionValue.split("\\s+"); + Options opts = getOptions(); + opts.addOption("help", false, "Show function help"); + CommandLine cl = CommandLineUtils.parseOptions(args, opts); + + if (cl.hasOption("help")) { + Description funcDesc = getClass().getAnnotation(Description.class); + final String cmdLineSyntax; + if (funcDesc == null) { + cmdLineSyntax = getClass().getSimpleName(); + } else { + String funcName = funcDesc.name(); + cmdLineSyntax = funcName == null ? 
getClass().getSimpleName() + : funcDesc.value().replace("_FUNC_", funcDesc.name()); + } + StringWriter sw = new StringWriter(); + sw.write('\n'); + PrintWriter pw = new PrintWriter(sw); + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts, + HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true); + pw.flush(); + String helpMsg = sw.toString(); + throw new UDFArgumentException(helpMsg); + } + + return cl; + } + + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + + if (argOIs.length != 5) { + throw new UDFArgumentException("At least 1 option `-topic` MUST be specified"); + } + + String rawArgs = HiveUtils.getConstString(argOIs[4]); + cl = parseOptions(rawArgs); + + this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 0); + if (topic < 1) { + throw new UDFArgumentException( + "A positive integer MUST be set to an option `-topic`: " + topic); + } + + this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic); + this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-5d); + + return cl; + } + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { + assert (parameters.length == 4 || parameters.length == 5); + super.init(mode, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + processOptions(parameters); + this.wordOI = HiveUtils.asStringOI(parameters[0]); + this.valueOI = HiveUtils.asDoubleCompatibleOI(parameters[1]); + this.labelOI = HiveUtils.asIntegerOI(parameters[2]); + this.lambdaOI = HiveUtils.asDoubleCompatibleOI(parameters[3]); + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) parameters[0]; + this.internalMergeOI = soi; + this.wcListField = soi.getStructFieldRef("wcList"); + this.lambdaMapField = soi.getStructFieldRef("lambdaMap"); + this.topicOptionField = soi.getStructFieldRef("topic"); + this.alphaOptionField = soi.getStructFieldRef("alpha"); + this.deltaOptionField = soi.getStructFieldRef("delta"); + this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + this.wcListOI = ObjectInspectorFactory.getStandardListObjectInspector(wcListElemOI); + this.lambdaMapKeyOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + this.lambdaMapValueElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + this.lambdaMapValueOI = ObjectInspectorFactory.getStandardListObjectInspector(lambdaMapValueElemOI); + this.lambdaMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + lambdaMapKeyOI, lambdaMapValueOI); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = internalMergeOI(); + } else { + final ArrayList<String> fieldNames = new ArrayList<String>(); + final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldNames.add("label"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("probability"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + outputOI = ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardStructObjectInspector( + fieldNames, fieldOIs)); + } + return outputOI; + } + + private static StructObjectInspector internalMergeOI() { + ArrayList<String> 
fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("wcList"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector)); + + fieldNames.add("lambdaMap"); + fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector))); + + fieldNames.add("topic"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + + fieldNames.add("alpha"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + fieldNames.add("delta"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @SuppressWarnings("deprecation") + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + AggregationBuffer myAggr = new OnlineLDAPredictAggregationBuffer(); + reset(myAggr); + return myAggr; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; + myAggr.reset(); + myAggr.setOptions(topic, alpha, delta); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; + + if (parameters[0] == null || parameters[1] == null || parameters[2] == null + || parameters[3] == null) { + return; + } + + String word = PrimitiveObjectInspectorUtils.getString(parameters[0], wordOI); + float value = HiveUtils.getFloat(parameters[1], valueOI); + int label = PrimitiveObjectInspectorUtils.getInt(parameters[2], labelOI); + float lambda = HiveUtils.getFloat(parameters[3], lambdaOI); + + myAggr.iterate(word, value, label, lambda); + } + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; + if (myAggr.wcList.size() == 0) { + return null; + } + + Object[] partialResult = new Object[5]; + partialResult[0] = myAggr.wcList; + partialResult[1] = myAggr.lambdaMap; + partialResult[2] = new IntWritable(myAggr.topic); + partialResult[3] = new FloatWritable(myAggr.alpha); + partialResult[4] = new DoubleWritable(myAggr.delta); + + return partialResult; + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + Object wcListObj = internalMergeOI.getStructFieldData(partial, wcListField); + + List<?> wcListRaw = wcListOI.getList(HiveUtils.castLazyBinaryObject(wcListObj)); + + // fix list elements to Java String objects + int wcListSize = wcListRaw.size(); + List<String> wcList = new ArrayList<String>(); + for (int i = 0; i < wcListSize; i++) { + wcList.add(PrimitiveObjectInspectorUtils.getString(wcListRaw.get(i), wcListElemOI)); + } + + Object lambdaMapObj = internalMergeOI.getStructFieldData(partial, lambdaMapField); + Map<?, ?> lambdaMapRaw = lambdaMapOI.getMap(HiveUtils.castLazyBinaryObject(lambdaMapObj)); + + Map<String, List<Float>> lambdaMap = new 
HashMap<String, List<Float>>(); + for (Map.Entry<?, ?> e : lambdaMapRaw.entrySet()) { + // fix map keys to Java String objects + String word = PrimitiveObjectInspectorUtils.getString(e.getKey(), lambdaMapKeyOI); + + Object lambdaMapValueObj = e.getValue(); + List<?> lambdaMapValueRaw = lambdaMapValueOI.getList(HiveUtils.castLazyBinaryObject(lambdaMapValueObj)); + + // fix map values to lists of Java Float objects + int lambdaMapValueSize = lambdaMapValueRaw.size(); + List<Float> lambda_word = new ArrayList<Float>(); + for (int i = 0; i < lambdaMapValueSize; i++) { + lambda_word.add(HiveUtils.getFloat(lambdaMapValueRaw.get(i), + lambdaMapValueElemOI)); + } + + lambdaMap.put(word, lambda_word); + } + + // restore options from partial result + Object topicObj = internalMergeOI.getStructFieldData(partial, topicOptionField); + this.topic = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj); + + Object alphaObj = internalMergeOI.getStructFieldData(partial, alphaOptionField); + this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj); + + Object deltaObj = internalMergeOI.getStructFieldData(partial, deltaOptionField); + this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(deltaObj); + + OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; + myAggr.setOptions(topic, alpha, delta); + myAggr.merge(wcList, lambdaMap); + } + + @Override + public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; + float[] topicDistr = myAggr.get(); + + SortedMap<Float, Integer> sortedDistr = new TreeMap<Float, Integer>( + Collections.reverseOrder()); + for (int i = 0; i < topicDistr.length; i++) { + sortedDistr.put(topicDistr[i], i); + } + + List<Object[]> result = new ArrayList<Object[]>(); + for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) { + Object[] struct = new Object[2]; + struct[0] = new IntWritable(e.getValue()); // label + struct[1] = new FloatWritable(e.getKey()); // probability + result.add(struct); + } + return result; + } + + } + + public static class OnlineLDAPredictAggregationBuffer extends + GenericUDAFEvaluator.AbstractAggregationBuffer { + + private List<String> wcList; + private Map<String, List<Float>> lambdaMap; + + private int topic; + private float alpha; + private double delta; + + OnlineLDAPredictAggregationBuffer() { + super(); + } + + void setOptions(int topic, float alpha, double delta) { + this.topic = topic; + this.alpha = alpha; + this.delta = delta; + } + + void reset() { + this.wcList = new ArrayList<String>(); + this.lambdaMap = new HashMap<String, List<Float>>(); + } + + void iterate(String word, float value, int label, float lambda) { + wcList.add(word + ":" + value); + + // for an unforeseen word, initialize its lambdas w/ -1s + if (!lambdaMap.containsKey(word)) { + List<Float> lambdaEmpty_word = new ArrayList<Float>( + Collections.nCopies(topic, -1.f)); + lambdaMap.put(word, lambdaEmpty_word); + } + + // set the given lambda value + List<Float> lambda_word = lambdaMap.get(word); + lambda_word.set(label, lambda); + lambdaMap.put(word, lambda_word); + } + + void merge(List<String> o_wcList, Map<String, List<Float>> o_lambdaMap) { + wcList.addAll(o_wcList); + + for (Map.Entry<String, List<Float>> e : o_lambdaMap.entrySet()) { + String o_word = e.getKey(); + List<Float> o_lambda_word = e.getValue(); + + if 
(!lambdaMap.containsKey(o_word)) { // for an unforeseen word + lambdaMap.put(o_word, o_lambda_word); + } else { // for a partially observed word + List<Float> lambda_word = lambdaMap.get(o_word); + for (int k = 0; k < topic; k++) { + if (o_lambda_word.get(k) != -1.f) { // not default value + lambda_word.set(k, o_lambda_word.get(k)); // set the partial lambda value + } + } + lambdaMap.put(o_word, lambda_word); + } + } + } + + float[] get() { + OnlineLDAModel model = new OnlineLDAModel(topic, alpha, delta); + + for (String word : lambdaMap.keySet()) { + List<Float> lambda_word = lambdaMap.get(word); + for (int k = 0; k < topic; k++) { + model.setLambda(word, k, lambda_word.get(k)); + } + } + + String[] wcArray = wcList.toArray(new String[wcList.size()]); + return model.getTopicDistribution(wcArray); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb16a394/core/src/main/java/hivemall/topicmodel/LDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java new file mode 100644 index 0000000..051d7ea --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java @@ -0,0 +1,559 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.topicmodel; + +import hivemall.UDTFWithOptions; +import hivemall.annotations.VisibleForTesting; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NioStatefullSegment; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.SizeOf; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.Counters; +import org.apache.hadoop.mapred.Reporter; + +@Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])" + + " - Returns a relation consisting of <int topic, string word, float score>") +public class LDAUDTF extends UDTFWithOptions { + private static final Log logger = LogFactory.getLog(LDAUDTF.class); + + // Options + protected int topic; + protected float alpha; + protected float eta; + protected int numDoc; + protected double tau0; + protected double kappa; + protected int iterations; + protected double delta; + protected double eps; + protected int miniBatchSize; + + // if `num_docs` option is not given, this flag will be true + // in that case, UDTF automatically sets `count` value to the _D parameter in an online LDA model + protected boolean isAutoD; + + // number of processed training samples + protected long count; + + protected String[][] miniBatch; + protected int miniBatchCount; + + protected OnlineLDAModel model; + + protected ListObjectInspector wordCountsOI; + + // for iterations + protected NioStatefullSegment fileIO; + protected ByteBuffer inputBuf; + + public LDAUDTF() { + this.topic = 10; + this.alpha = 1.f / topic; + this.eta = 1.f / topic; + this.numDoc = -1; + this.tau0 = 64.d; + this.kappa = 0.7; + this.iterations = 1; + this.delta = 1E-5d; + this.eps = 1E-1d; + this.miniBatchSize = 1; // truly online setting + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", "topic", true, "The number of topics [default: 10]"); + opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]"); + opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]"); + opts.addOption("d", "num_docs", true, "The total number of documents [default: auto]"); + opts.addOption("tau", "tau0", true, + "The parameter which downweights early iterations [default: 64.0]"); + opts.addOption("kappa", true, "Exponential decay rate (i.e., learning rate) [default: 0.7]"); 
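+ // note: processOptions() below rejects tau0 <= 0 and kappa outside (0.5, 1.0]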
+ opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 1]"); + opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-5]"); + opts.addOption("eps", "epsilon", true, + "Check convergence based on the difference of perplexity [default: 1E-1]"); + opts.addOption("s", "mini_batch_size", true, + "Repeat model updating per mini-batch [default: 1]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + + if (argOIs.length >= 2) { + String rawArgs = HiveUtils.getConstString(argOIs[1]); + cl = parseOptions(rawArgs); + this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 10); + this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic); + this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topic); + this.numDoc = Primitives.parseInt(cl.getOptionValue("num_docs"), -1); + this.tau0 = Primitives.parseDouble(cl.getOptionValue("tau0"), 64.d); + if (tau0 <= 0.d) { + throw new UDFArgumentException("'-tau0' must be positive: " + tau0); + } + this.kappa = Primitives.parseDouble(cl.getOptionValue("kappa"), 0.7d); + if (kappa <= 0.5 || kappa > 1.d) { + throw new UDFArgumentException("'-kappa' must be in (0.5, 1.0]: " + kappa); + } + this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1); + if (iterations < 1) { + throw new UDFArgumentException( + "'-iterations' must be greater than or equals to 1: " + iterations); + } + this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-5d); + this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d); + this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 1); + } + + return cl; + } + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length < 1) { + throw new UDFArgumentException( + "_FUNC_ takes 1 arguments: array<string> words [, const string options]"); + } + + this.wordCountsOI = HiveUtils.asListOI(argOIs[0]); + HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector()); + + processOptions(argOIs); + + this.model = new OnlineLDAModel(topic, alpha, eta, numDoc, tau0, kappa, delta); + this.count = 0L; + this.isAutoD = (numDoc < 0); + this.miniBatch = new String[miniBatchSize][]; + this.miniBatchCount = 0; + + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldNames.add("topic"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("word"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + fieldNames.add("score"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public void process(Object[] args) throws HiveException { + int length = wordCountsOI.getListLength(args[0]); + String[] wordCounts = new String[length]; + int j = 0; + for (int i = 0; i < length; i++) { + Object o = wordCountsOI.getListElement(args[0], i); + if (o == null) { + continue; + } + String s = o.toString(); + wordCounts[j] = s; + j++; + } + + count++; + if (isAutoD) { + model.setNumTotalDocs((int) count); + } + + recordTrainSampleToTempFile(wordCounts); + + miniBatch[miniBatchCount] = wordCounts; + miniBatchCount++; + + if (miniBatchCount == miniBatchSize) { 
+ model.train(miniBatch); + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + } + } + + protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts) + throws HiveException { + if (iterations == 1) { + return; + } + + ByteBuffer buf = inputBuf; + NioStatefullSegment dst = fileIO; + + if (buf == null) { + final File file; + try { + file = File.createTempFile("hivemall_lda", ".sgmt"); + file.deleteOnExit(); + if (!file.canWrite()) { + throw new UDFArgumentException("Cannot write a temporary file: " + + file.getAbsolutePath()); + } + logger.info("Record training samples to a file: " + file.getAbsolutePath()); + } catch (IOException ioe) { + throw new UDFArgumentException(ioe); + } catch (Throwable e) { + throw new UDFArgumentException(e); + } + this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB + this.fileIO = dst = new NioStatefullSegment(file, false); + } + + int wcLength = 0; + for (String wc : wordCounts) { + if (wc == null) { + continue; + } + wcLength += wc.getBytes().length; + } + // recordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ... + int recordBytes = (Integer.SIZE * 2 + Integer.SIZE * wcLength) / 8 + wcLength; + int remain = buf.remaining(); + if (remain < recordBytes) { + writeBuffer(buf, dst); + } + + buf.putInt(recordBytes); + buf.putInt(wordCounts.length); + for (String wc : wordCounts) { + if (wc == null) { + continue; + } + byte[] b = wc.getBytes(); + buf.putInt(b.length); // record the byte length (not the char count) so the reader allocates correctly + buf.put(b); + } + } + + private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst) + throws HiveException { + srcBuf.flip(); + try { + dst.write(srcBuf); + } catch (IOException e) { + throw new HiveException("Exception caused while writing a buffer to file", e); + } + srcBuf.clear(); + } + + @Override + public void close() throws HiveException { + if (count == 0) { + this.model = null; + return; + } + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + } + if (iterations > 1) { + runIterativeTraining(iterations); + } + forwardModel(); + this.model = null; + } + + protected final void runIterativeTraining(@Nonnegative final int iterations) + throws HiveException { + final ByteBuffer buf = this.inputBuf; + final NioStatefullSegment dst = this.fileIO; + assert (buf != null); + assert (dst != null); + final long numTrainingExamples = count; + + final Reporter reporter = getReporter(); + final Counters.Counter iterCounter = (reporter == null) ? 
null : reporter.getCounter( + "hivemall.lda.OnlineLDA$Counter", "iteration"); + + try { + if (dst.getPosition() == 0L) {// run iterations w/o temporary file + if (buf.position() == 0) { + return; // no training example + } + buf.flip(); + + int iter = 2; + float perplexityPrev = Float.MAX_VALUE; + float perplexity; + int numTrain; + for (; iter <= iterations; iter++) { + perplexity = 0.f; + numTrain = 0; + + reportProgress(reporter); + setCounterValue(iterCounter, iter); + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + + while (buf.remaining() > 0) { + int recordBytes = buf.getInt(); + assert (recordBytes > 0) : recordBytes; + int wcLength = buf.getInt(); + final String[] wordCounts = new String[wcLength]; + for (int j = 0; j < wcLength; j++) { + int len = buf.getInt(); + byte[] bytes = new byte[len]; + buf.get(bytes); + wordCounts[j] = new String(bytes); + } + + miniBatch[miniBatchCount] = wordCounts; + miniBatchCount++; + + if (miniBatchCount == miniBatchSize) { + model.train(miniBatch); + perplexity += model.computePerplexity(); + numTrain++; + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + } + } + buf.rewind(); + + // update for remaining samples + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + perplexity += model.computePerplexity(); + numTrain++; + } + + logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain); + perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches + if (Math.abs(perplexityPrev - perplexity) < eps) { + break; + } + perplexityPrev = perplexity; + } + logger.info("Performed " + + Math.min(iter, iterations) + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on memory (thus " + + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) + + " training updates in total) "); + } else {// read training examples in the temporary file and invoke train for each example + + // write training examples in buffer to a temporary file + if (buf.remaining() > 0) { + writeBuffer(buf, dst); + } + try { + dst.flush(); + } catch (IOException e) { + throw new HiveException("Failed to flush a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (logger.isInfoEnabled()) { + File tmpFile = dst.getFile(); + logger.info("Wrote " + numTrainingExamples + + " records to a temporary file for iterative training: " + + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + + ")"); + } + + // run iterations + int iter = 2; + float perplexityPrev = Float.MAX_VALUE; + float perplexity; + int numTrain; + for (; iter <= iterations; iter++) { + perplexity = 0.f; + numTrain = 0; + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + + setCounterValue(iterCounter, iter); + + buf.clear(); + dst.resetPosition(); + while (true) { + reportProgress(reporter); + // TODO prefetch + // writes training examples to a buffer in the temporary file + final int bytesRead; + try { + bytesRead = dst.read(buf); + } catch (IOException e) { + throw new HiveException("Failed to read a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (bytesRead == 0) { // reached file EOF + break; + } + assert (bytesRead > 0) : bytesRead; + + // reads training examples from a buffer + buf.flip(); + int remain = buf.remaining(); + if (remain < SizeOf.INT) { + throw new HiveException("Illegal file format was detected"); + } + while (remain >= SizeOf.INT) { + int pos = buf.position(); + int recordBytes 
= buf.getInt(); + remain -= SizeOf.INT; + if (remain < recordBytes) { + buf.position(pos); + break; + } + + int wcLength = buf.getInt(); + final String[] wordCounts = new String[wcLength]; + for (int j = 0; j < wcLength; j++) { + int len = buf.getInt(); + byte[] bytes = new byte[len]; + buf.get(bytes); + wordCounts[j] = new String(bytes); + } + + miniBatch[miniBatchCount] = wordCounts; + miniBatchCount++; + + if (miniBatchCount == miniBatchSize) { + model.train(miniBatch); + perplexity += model.computePerplexity(); + numTrain++; + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + } + + remain -= recordBytes; + } + buf.compact(); + } + + // update for remaining samples + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + perplexity += model.computePerplexity(); + numTrain++; + } + + logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain); + perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches + if (Math.abs(perplexityPrev - perplexity) < eps) { + break; + } + perplexityPrev = perplexity; + } + logger.info("Performed " + + Math.min(iter, iterations) + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on a secondary storage (thus " + + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) + + " training updates in total)"); + } + } finally { + // delete the temporary file and release resources + try { + dst.close(true); + } catch (IOException e) { + throw new HiveException("Failed to close a file: " + + dst.getFile().getAbsolutePath(), e); + } + this.inputBuf = null; + this.fileIO = null; + } + } + + protected void forwardModel() throws HiveException { + final IntWritable topicIdx = new IntWritable(); + final Text word = new Text(); + final FloatWritable score = new FloatWritable(); + + final Object[] forwardObjs = new Object[3]; + forwardObjs[0] = topicIdx; + forwardObjs[1] = word; + forwardObjs[2] = score; + + for (int k = 0; k < topic; k++) { + topicIdx.set(k); + + final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + score.set(e.getKey()); + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + word.set(words.get(i)); + forward(forwardObjs); + } + } + } + + logger.info("Forwarded topic words each of " + topic + " topics"); + } + + /* + * For testing: + */ + + @VisibleForTesting + double getLambda(String label, int k) { + return model.getLambda(label, k); + } + + @VisibleForTesting + SortedMap<Float, List<String>> getTopicWords(int k) { + return model.getTopicWords(k); + } + + @VisibleForTesting + SortedMap<Float, List<String>> getTopicWords(int k, int topN) { + return model.getTopicWords(k, topN); + } + + @VisibleForTesting + float[] getTopicDistribution(@Nonnull String[] doc) { + return model.getTopicDistribution(doc); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb16a394/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java new file mode 100644 index 0000000..8474212 --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java @@ -0,0 +1,522 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.topicmodel; + +import hivemall.annotations.VisibleForTesting; +import hivemall.utils.lang.Preconditions; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +import org.apache.commons.math3.distribution.GammaDistribution; +import org.apache.commons.math3.special.Gamma; + +public final class OnlineLDAModel { + + // number of topics + private int _K; + + // prior on weight vectors "theta ~ Dir(alpha_)" + private float _alpha = 1 / 2.f; + + // prior on topics "beta" + private float _eta = 1 / 20.f; + + // total number of documents + // in the truly online setting, this can be an estimate of the maximum number of documents that could ever seen + private int _D = -1; + + // defined by (tau0 + updateCount)^(-kappa_) + // controls how much old lambda is forgotten + private double _rhot; + + // positive value which downweights early iterations + @Nonnegative + private double _tau0 = 1020; + + // exponential decay rate (i.e., learning rate) which must be in (0.5, 1] to guarantee convergence + private double _kappa = 0.7d; + + // how many times EM steps are launched; later EM steps do not drastically forget old lambda + private long _updateCount = 0L; + + // random number generator + @Nonnull + private final GammaDistribution _gd; + private static double SHAPE = 100.d; + private static double SCALE = 1.d / SHAPE; + + // parameters + @Nonnull + private List<Map<String, float[]>> _phi; + private float[][] _gamma; + @Nonnull + private final Map<String, float[]> _lambda; + + // check convergence in the expectation (E) step + private double _delta = 1e-5; + + @Nonnull + private final List<Map<String, Float>> _miniBatchMap; + private int _miniBatchSize; + + // for computing perplexity + private int _docCount = 0; + private int _wordCount = 0; + + public OnlineLDAModel(int K, float alpha, double delta) { // for E step only instantiation + this(K, alpha, 1 / 20.f, -1, 1020, 0.7, delta); + } + + public OnlineLDAModel(int K, float alpha, float eta, int D, double tau0, double kappa, + double delta) { + Preconditions.checkArgument(0.d < tau0, "tau0 MUST be positive: " + tau0); + Preconditions.checkArgument(0.5 < kappa && kappa <= 1.d, "kappa MUST be in (0.5, 1.0]: " + + kappa); + + _K = K; + _alpha = alpha; + _eta = eta; + _D = D; + _tau0 = tau0; + _kappa = kappa; + _delta = delta; + + // initialize a random number generator + _gd = new GammaDistribution(SHAPE, SCALE); + _gd.reseedRandomGenerator(1001); + + // initialize the parameters + _lambda = new HashMap<String, float[]>(100); + + this._miniBatchMap = new ArrayList<Map<String, Float>>(); + } + 
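+ // Note on the online update (cf. train() and mStep() below): each call to train() computes
+ //   rho_t = (tau0 + updateCount)^(-kappa)
+ // and blends the fitted mini-batch statistics into the topic-word parameters as
+ //   lambda <- (1 - rho_t) * lambda + rho_t * lambda_new,
+ // so early mini-batches are gradually forgotten at a rate controlled by tau0 and kappa.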
+ public void setNumTotalDocs(@Nonnegative int D) { + _D = D; + } + + public void train(@Nonnull final String[][] miniBatch) { + Preconditions.checkArgument(_D > 0, + "Total number of documents MUST be set via `setNumTotalDocs()`"); + + _miniBatchSize = miniBatch.length; + + // get the number of words(Nd) for each documents + _wordCount = 0; + for (final String[] e : miniBatch) { + if (e != null) { + _wordCount += e.length; + } + } + _docCount = _miniBatchSize; + + initMiniBatchMap(miniBatch, _miniBatchMap); + + initParams(true); + + _rhot = Math.pow(_tau0 + _updateCount, -_kappa); + + // Expectation + eStep(); + + // Maximization + mStep(); + + _updateCount++; + } + + private static void initMiniBatchMap(@Nonnull final String[][] miniBatch, + @Nonnull final List<Map<String, Float>> map) { + map.clear(); + + // parse document + for (final String[] e : miniBatch) { + if (e == null) { + continue; + } + + final Map<String, Float> docMap = new HashMap<String, Float>(); + + // parse features + for (String fv : e) { + String[] parsedFeature = fv.split(":"); // [`label`, `value`] + if (parsedFeature.length == 1) { // wrong format + continue; + } + String label = parsedFeature[0]; + float value = Float.parseFloat(parsedFeature[1]); + docMap.put(label, value); + } + + map.add(docMap); + } + } + + private void initParams(boolean gammaWithRandom) { + _phi = new ArrayList<Map<String, float[]>>(); + _gamma = new float[_miniBatchSize][]; + + for (int d = 0; d < _miniBatchSize; d++) { + if (gammaWithRandom) { + _gamma[d] = generateRandomFloatArray(_K, _gd); + } else { + _gamma[d] = new float[_K]; + Arrays.fill(_gamma[d], 1.f); + } + + // phi does not needed to be initialized + Map<String, float[]> phi_d = new HashMap<String, float[]>(); + for (String label : _miniBatchMap.get(d).keySet()) { + phi_d.put(label, new float[_K]); + } + _phi.add(phi_d); + + // lambda for newly observed word + for (String label : _miniBatchMap.get(d).keySet()) { + if (!_lambda.containsKey(label)) { + _lambda.put(label, generateRandomFloatArray(_K, _gd)); + } + } + } + } + + @Nonnull + private static float[] generateRandomFloatArray(@Nonnegative final int size, + @Nonnull final GammaDistribution gd) { + final float[] ret = new float[size]; + for (int i = 0; i < size; i++) { + ret[i] = (float) gd.sample(); + } + return ret; + } + + private void eStep() { + float[] gammaPrev_d; + + // for each of mini-batch documents, update gamma until convergence + for (int d = 0; d < _miniBatchSize; d++) { + do { + // (deep) copy the last gamma values + gammaPrev_d = _gamma[d].clone(); + + updatePhiForSingleDoc(d); + updateGammaForSingleDoc(d); + } while (!checkGammaDiff(gammaPrev_d, _gamma[d])); + } + } + + private void updatePhiForSingleDoc(@Nonnegative final int d) { + // dirichlet_expectation_2d(gamma_) + final float[] eLogTheta_d = new float[_K]; + float gammaSum_d = 0.f; + for (int k = 0; k < _K; k++) { + gammaSum_d += _gamma[d][k]; + } + for (int k = 0; k < _K; k++) { + eLogTheta_d[k] = (float) (Gamma.digamma(_gamma[d][k]) - Gamma.digamma(gammaSum_d)); + } + + // dirichlet_expectation_2d(lambda_) + final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>(); + for (int k = 0; k < _K; k++) { + float lambdaSum_k = 0.f; + for (String label : _lambda.keySet()) { + lambdaSum_k += _lambda.get(label)[k]; + } + for (String label : _miniBatchMap.get(d).keySet()) { + float[] eLogBeta_label; + if (eLogBeta_d.containsKey(label)) { + eLogBeta_label = eLogBeta_d.get(label); + } else { + eLogBeta_label = new float[_K]; + 
Arrays.fill(eLogBeta_label, 0.f); + } + + eLogBeta_label[k] = (float) (Gamma.digamma(_lambda.get(label)[k]) - Gamma.digamma(lambdaSum_k)); + eLogBeta_d.put(label, eLogBeta_label); + } + } + + // updating phi w/ normalization + for (String label : _miniBatchMap.get(d).keySet()) { + float normalizer = 0.f; + for (int k = 0; k < _K; k++) { + float phi_dwk = (float) Math.exp(eLogTheta_d[k] + eLogBeta_d.get(label)[k]) + 1E-20f; + _phi.get(d).get(label)[k] = phi_dwk; + normalizer += phi_dwk; + } + + // normalize + for (int k = 0; k < _K; k++) { + _phi.get(d).get(label)[k] /= normalizer; + } + } + } + + private void updateGammaForSingleDoc(@Nonnegative final int d) { + for (int k = 0; k < _K; k++) { + float gamma_dk = _alpha; + for (String label : _miniBatchMap.get(d).keySet()) { + gamma_dk += _phi.get(d).get(label)[k] * _miniBatchMap.get(d).get(label); + } + _gamma[d][k] = gamma_dk; + } + } + + private boolean checkGammaDiff(@Nonnull final float[] gammaPrev, + @Nonnull final float[] gammaNext) { + double diff = 0.d; + for (int k = 0; k < _K; k++) { + diff += Math.abs(gammaPrev[k] - gammaNext[k]); + } + return (diff / _K) < _delta; + } + + private void mStep() { + // calculate lambdaNext + final Map<String, float[]> lambdaNext = new HashMap<String, float[]>(); + + float docRatio = (float) _D / (float) _miniBatchSize; + + for (int d = 0; d < _miniBatchSize; d++) { + for (String label : _miniBatchMap.get(d).keySet()) { + float[] lambdaNext_label; + if (lambdaNext.containsKey(label)) { + lambdaNext_label = lambdaNext.get(label); + } else { + lambdaNext_label = new float[_K]; + Arrays.fill(lambdaNext_label, _eta); + } + for (int k = 0; k < _K; k++) { + lambdaNext_label[k] += docRatio * _phi.get(d).get(label)[k]; + } + lambdaNext.put(label, lambdaNext_label); + } + } + + // update lambda_ + for (Map.Entry<String, float[]> e : _lambda.entrySet()) { + String label = e.getKey(); + float[] lambda_label = e.getValue(); + + float[] lambdaNext_label; + if (lambdaNext.containsKey(label)) { + lambdaNext_label = lambdaNext.get(label); + } else { + lambdaNext_label = new float[_K]; + Arrays.fill(lambdaNext_label, _eta); + } + for (int k = 0; k < _K; k++) { + lambda_label[k] = (float) ((1.d - _rhot) * lambda_label[k] + _rhot + * lambdaNext_label[k]); + } + _lambda.put(label, lambda_label); + } + } + + /* + * Methods for debugging and checking convergence: + */ + + /** + * Calculate approximate perplexity for the current mini-batch. + */ + public float computePerplexity() { + float bound = computeApproxBoundForMiniBatch(); + float perWordBound = bound / (float) _wordCount; + return (float) Math.exp(-1.f * perWordBound); + } + + /** + * Estimates the variational bound over all documents using only the documents passed as mini-batch. 
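+ * The bound is consumed by computePerplexity() above, which reports exp(-bound / wordCount).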
+ */ +private float computeApproxBoundForMiniBatch() { + float score = 0.f; + float tmp; + + // prepare + float[] gammaSum = new float[_miniBatchSize]; + Arrays.fill(gammaSum, 0.f); + for (int d = 0; d < _miniBatchSize; d++) { + for (int k = 0; k < _K; k++) { + gammaSum[d] += _gamma[d][k]; + } + } + float[] lambdaSum = new float[_K]; + Arrays.fill(lambdaSum, 0.f); + for (int k = 0; k < _K; k++) { + for (String label : _lambda.keySet()) { + lambdaSum[k] += _lambda.get(label)[k]; + } + } + + // E[log p(docs | theta, beta)] + for (int d = 0; d < _miniBatchSize; d++) { + + // for each word in the document + for (String label : _miniBatchMap.get(d).keySet()) { + float wordCount = _miniBatchMap.get(d).get(label); + + tmp = 0.f; + for (int k = 0; k < _K; k++) { + float eLogTheta_dk = (float) (Gamma.digamma(_gamma[d][k]) - Gamma.digamma(gammaSum[d])); + float eLogBeta_kw = 0.f; + if (_lambda.containsKey(label)) { // look up by the word's String key, not the document index + eLogBeta_kw = (float) (Gamma.digamma(_lambda.get(label)[k]) - Gamma.digamma(lambdaSum[k])); + } + + tmp += _phi.get(d).get(label)[k] + * (eLogTheta_dk + eLogBeta_kw - Math.log(_phi.get(d).get(label)[k])); + } + score += wordCount * tmp; + } + + // E[log p(theta | alpha) - log q(theta | gamma)] + score -= (Gamma.logGamma(gammaSum[d])); + tmp = 0.f; + for (int k = 0; k < _K; k++) { + tmp += (_alpha - _gamma[d][k]) + * (Gamma.digamma(_gamma[d][k]) - Gamma.digamma(gammaSum[d])) + + (Gamma.logGamma(_gamma[d][k])); + tmp /= _docCount; + } + score += tmp; + + // E[log p(beta | eta) - log q (beta | lambda)] + tmp = 0.f; + for (int k = 0; k < _K; k++) { + float tmpPartial = 0.f; + for (String label : _lambda.keySet()) { + tmpPartial += (_eta - _lambda.get(label)[k]) + * (float) (Gamma.digamma(_lambda.get(label)[k]) - Gamma.digamma(lambdaSum[k])) + + (float) (Gamma.logGamma(_lambda.get(label)[k])); + } + + tmp += (tmpPartial - (float) Gamma.logGamma(lambdaSum[k])); + } + score += (tmp / _miniBatchSize); + + float W = _lambda.size(); + tmp = (float) (Gamma.logGamma(_K * _alpha)) - (float) (_K * (Gamma.logGamma(_alpha))) + (((float) (Gamma.logGamma(W * _eta)) - (float) (W * (Gamma.logGamma(_eta)))) / _docCount); + score += tmp; + } + + return score; + } + + @VisibleForTesting + double getLambda(@Nonnull String label, int k) { + final float[] lambda = _lambda.get(label); + if (lambda == null) { + throw new IllegalArgumentException("Word `" + label + "` is not in the corpus."); + } + if (k >= lambda.length) { + throw new IllegalArgumentException("Topic index must be in [0, " + lambda.length + ")"); + } + return lambda[k]; + } + + public void setLambda(@Nonnull String label, int k, float lambda) { + final float[] lambda_label; + if (!_lambda.containsKey(label)) { + lambda_label = generateRandomFloatArray(_K, _gd); + _lambda.put(label, lambda_label); + } else { + lambda_label = _lambda.get(label); + } + lambda_label[k] = lambda; + } + + @Nonnull + public SortedMap<Float, List<String>> getTopicWords(int k) { + return getTopicWords(k, _lambda.keySet().size()); + } + + @Nonnull + public SortedMap<Float, List<String>> getTopicWords(int k, int topN) { + float lambdaSum = 0.f; + final SortedMap<Float, List<String>> sortedLambda = new TreeMap<Float, List<String>>( + Collections.reverseOrder()); + + for (String label : _lambda.keySet()) { + float lambda = _lambda.get(label)[k]; + lambdaSum += lambda; + + List<String> labels = new ArrayList<String>(); + if (sortedLambda.containsKey(lambda)) { + labels = sortedLambda.get(lambda); + } + labels.add(label); + + sortedLambda.put(lambda, 
labels); + } + + final SortedMap<Float, List<String>> ret = new TreeMap<Float, List<String>>( + Collections.reverseOrder()); + + topN = Math.min(topN, _lambda.keySet().size()); + int tt = 0; + for (Map.Entry<Float, List<String>> e : sortedLambda.entrySet()) { + float lambda = e.getKey(); + List<String> labels = e.getValue(); + ret.put(lambda / lambdaSum, labels); + + if (++tt == topN) { + break; + } + } + + return ret; + } + + @Nonnull + public float[] getTopicDistribution(@Nonnull final String[] doc) { + _miniBatchSize = 1; + initMiniBatchMap(new String[][] {doc}, _miniBatchMap); + initParams(false); + + eStep(); + + float[] topicDistr = new float[_K]; + // normalize topic distribution + float gammaSum = 0.f; + for (int k = 0; k < _K; k++) { + gammaSum += _gamma[0][k]; + } + for (int k = 0; k < _K; k++) { + topicDistr[k] = _gamma[0][k] / gammaSum; + } + return topicDistr; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb16a394/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java new file mode 100644 index 0000000..a23d917 --- /dev/null +++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package hivemall.topicmodel;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class LDAPredictUDAFTest {
+    LDAPredictUDAF udaf;
+    GenericUDAFEvaluator evaluator;
+    ObjectInspector[] inputOIs;
+    ObjectInspector[] partialOI;
+    LDAPredictUDAF.OnlineLDAPredictAggregationBuffer agg;
+
+    String[] words;
+    int[] labels;
+    float[] lambdas;
+
+    @Test(expected = UDFArgumentException.class)
+    public void testWithoutOption() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+    }
+
+    @Test(expected = UDFArgumentException.class)
+    public void testWithoutTopicOption() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-alpha 0.1")};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+    }
+
+    @Before
+    public void setUp() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2")};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        ArrayList<String> fieldNames = new ArrayList<String>();
+        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+        fieldNames.add("wcList");
+        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector));
+
+        fieldNames.add("lambdaMap");
+        fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            ObjectInspectorFactory.getStandardListObjectInspector(
+                PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
+
+        fieldNames.add("topic");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+        fieldNames.add("alpha");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+        fieldNames.add("delta");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+        partialOI = new ObjectInspector[4];
+        partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+
+        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
+        words = new String[] {"fruits", "vegetables", "healthy", "flu", "apples", "oranges",
+                "like", "avocados", "colds", "colds", "avocados", "oranges", "like", "apples",
+                "flu", "healthy", "vegetables", "fruits"};
+        labels = new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+        lambdas = new float[] {0.3339331f, 0.3324783f, 0.33209667f, 3.2804057E-4f, 3.0303953E-4f,
+                2.4860457E-4f, 2.41481E-4f, 2.3554532E-4f, 1.352576E-4f, 0.1660153f, 0.16596903f,
+                0.1659654f, 0.1659627f, 0.16593699f, 0.1659259f, 0.0017611005f, 0.0015791848f,
+                8.84464E-4f};
+    }
+
+    @Test
+    public void test() throws Exception {
+        final Map<String, Float> doc1 = new HashMap<String, Float>();
+        doc1.put("fruits", 1.f);
+        doc1.put("healthy", 1.f);
+        doc1.put("vegetables", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc1.get(word), labels[i], lambdas[i]});
+        }
+        float[] doc1Distr = agg.get();
+
+        final Map<String, Float> doc2 = new HashMap<String, Float>();
+        doc2.put("apples", 1.f);
+        doc2.put("avocados", 1.f);
+        doc2.put("colds", 1.f);
+        doc2.put("flu", 1.f);
+        doc2.put("like", 2.f);
+        doc2.put("oranges", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc2.get(word), labels[i], lambdas[i]});
+        }
+        float[] doc2Distr = agg.get();
+
+        Assert.assertTrue(doc1Distr[0] > doc2Distr[0]);
+        Assert.assertTrue(doc1Distr[1] < doc2Distr[1]);
+    }
+
+    @Test
+    public void testMerge() throws Exception {
+        final Map<String, Float> doc = new HashMap<String, Float>();
+        doc.put("apples", 1.f);
+        doc.put("avocados", 1.f);
+        doc.put("colds", 1.f);
+        doc.put("flu", 1.f);
+        doc.put("like", 2.f);
+        doc.put("oranges", 1.f);
+
+        Object[] partials = new Object[3];
+
+        // bin #1
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < 6; i++) {
+            evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], lambdas[i]});
+        }
+        partials[0] = evaluator.terminatePartial(agg);
+
+        // bin #2
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 6; i < 12; i++) {
+            evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], lambdas[i]});
+        }
+        partials[1] = evaluator.terminatePartial(agg);
+
+        // bin #3
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 12; i < 18; i++) {
+            evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], lambdas[i]});
+        }
+        partials[2] = evaluator.terminatePartial(agg);
+
+        // merge in a different order
+        final int[][] orders = new int[][] {{0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}};
+        for (int i = 0; i < orders.length; i++) {
+            evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
+            evaluator.reset(agg);
+
+            evaluator.merge(agg, partials[orders[i][0]]);
+            evaluator.merge(agg, partials[orders[i][1]]);
+            evaluator.merge(agg, partials[orders[i][2]]);
+
+            float[] distr = agg.get();
+            Assert.assertTrue(distr[0] < distr[1]);
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb16a394/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
new file mode 100644
index 0000000..7bd9139
--- /dev/null
+++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.text.ParseException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.StringTokenizer;
+import java.util.zip.GZIPInputStream;
+
+import hivemall.classifier.KernelExpansionPassiveAggressiveUDTFTest;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.annotation.Nonnull;
+
+public class LDAUDTFTest {
+    private static final boolean DEBUG = false;
+
+    @Test
+    public void test() throws HiveException {
+        LDAUDTF udtf = new LDAUDTF();
+
+        ObjectInspector[] argOIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                    "-topic 2 -num_docs 2")};
+
+        udtf.initialize(argOIs);
+
+        String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
+        String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"};
+        for (int it = 0; it < 5; it++) {
+            udtf.process(new Object[] {Arrays.asList(doc1)});
+            udtf.process(new Object[] {Arrays.asList(doc2)});
+        }
+
+        SortedMap<Float, List<String>> topicWords;
+
+        println("Topic 0:");
+        println("========");
+        topicWords = udtf.getTopicWords(0);
+        for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+            List<String> words = e.getValue();
+            for (int i = 0; i < words.size(); i++) {
+                println(e.getKey() + " " + words.get(i));
+            }
+        }
+        println("========");
+
+        println("Topic 1:");
+        println("========");
+        topicWords = udtf.getTopicWords(1);
+        for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+            List<String> words = e.getValue();
+            for (int i = 0; i < words.size(); i++) {
+                println(e.getKey() + " " + words.get(i));
+            }
+        }
+        println("========");
+
+        int k1, k2;
+        float[] topicDistr = udtf.getTopicDistribution(doc1);
+        if (topicDistr[0] > topicDistr[1]) {
+            // topic 0 MUST represent doc#1
+            k1 = 0;
+            k2 = 1;
+        } else {
+            k1 = 1;
+            k2 = 0;
+        }
+
+        Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+                + "and `vegetables` SHOULD be a more suitable topic word than `flu` in that topic",
+            udtf.getLambda("vegetables", k1) > udtf.getLambda("flu", k1));
+        Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+                + "and `avocados` SHOULD be a more suitable topic word than `healthy` in that topic",
+            udtf.getLambda("avocados", k2) > udtf.getLambda("healthy", k2));
+    }
+
+    @Test
+    public void testNews20() throws IOException, ParseException, HiveException {
+        LDAUDTF udtf = new LDAUDTF();
+
+        ObjectInspector[] argOIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                    "-topic 20 -delta 0.1 -num_docs 100")};
+
+        udtf.initialize(argOIs);
+
+        BufferedReader news20 = readFile("news20-small.binary.gz");
+
+        List<String> doc = new ArrayList<String>();
+
+        String[] docInClass1 = new String[0];
+        String[] docInClass2 = new String[0];
+
+        String line = news20.readLine();
+        while (line != null) {
+            StringTokenizer tokens = new StringTokenizer(line, " ");
+            int label = Integer.parseInt(tokens.nextToken());
+
+            while (tokens.hasMoreTokens()) {
+                doc.add(tokens.nextToken());
+            }
+
+            udtf.process(new Object[] {doc});
+
+            if (docInClass1.length == 0 && label == 1) { // store the first +1 document
+                docInClass1 = doc.toArray(new String[doc.size()]);
+            } else if (docInClass2.length == 0 && label == -1) { // store the first -1 document
+                docInClass2 = doc.toArray(new String[doc.size()]);
+            }
+
+            doc.clear();
+            line = news20.readLine();
+        }
+        news20.close();
+
+        SortedMap<Float, List<String>> topicWords;
+
+        for (int k = 0; k < 20; k++) {
+            println("========");
+            println("Topic " + k);
+            topicWords = udtf.getTopicWords(k, 5);
+            for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+                List<String> words = e.getValue();
+                for (int i = 0; i < words.size(); i++) {
+                    println(e.getKey() + " " + words.get(i));
+                }
+            }
+            println("========");
+        }
+
+        int k1 = findMaxTopic(udtf.getTopicDistribution(docInClass1));
+        int k2 = findMaxTopic(udtf.getTopicDistribution(docInClass2));
+        Assert.assertTrue("Documents from class #1 (+1) and class #2 (-1) were both assigned to topic "
+                + k1 + ". Documents in different classes SHOULD be assigned to different topics.",
+            k1 != k2);
+    }
+
+    private static void println(String msg) {
+        if (DEBUG) {
+            System.out.println(msg);
+        }
+    }
+
+    @Nonnull
+    private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
+        // reuse the data set stored for the KPA UDTF test
+        InputStream is = KernelExpansionPassiveAggressiveUDTFTest.class.getResourceAsStream(fileName);
+        if (fileName.endsWith(".gz")) {
+            is = new GZIPInputStream(is);
+        }
+        return new BufferedReader(new InputStreamReader(is));
+    }
+
+    private static int findMaxTopic(@Nonnull float[] topicDistr) {
+        int maxIdx = 0;
+        for (int i = 1; i < topicDistr.length; i++) {
+            if (topicDistr[maxIdx] < topicDistr[i]) {
+                maxIdx = i;
+            }
+        }
+        return maxIdx;
+    }
+}
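
For readers skimming the diff, the two tests above condense into the following
standalone driver for the new LDAUDTF API. This is a minimal sketch, not part
of the commit: the class name LDAQuickStart and the main() harness are
illustrative only, while every call it makes (initialize, process,
getTopicWords, getTopicDistribution) and the "-topic 2 -num_docs 2" option
string appear verbatim in LDAUDTFTest#test.

import java.util.Arrays;
import java.util.List;
import java.util.SortedMap;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import hivemall.topicmodel.LDAUDTF;

public class LDAQuickStart {
    public static void main(String[] args) throws HiveException {
        LDAUDTF udtf = new LDAUDTF();

        // same argument shape the tests use: array<string> features plus a constant option string
        udtf.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                    "-topic 2 -num_docs 2")});

        // "word:count" features, taken from LDAUDTFTest#test
        String[] doc1 = {"fruits:1", "healthy:1", "vegetables:1"};
        String[] doc2 = {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"};
        for (int it = 0; it < 5; it++) { // a few passes over the mini corpus
            udtf.process(new Object[] {Arrays.asList(doc1)});
            udtf.process(new Object[] {Arrays.asList(doc2)});
        }

        // inspect the per-topic word rankings and a per-document topic distribution
        for (int k = 0; k < 2; k++) {
            SortedMap<Float, List<String>> topicWords = udtf.getTopicWords(k);
            System.out.println("topic " + k + ": " + topicWords);
        }
        System.out.println(Arrays.toString(udtf.getTopicDistribution(doc1)));
    }
}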