Close #66: [HIVEMALL-91] Implement Online LDA

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9b2ddcc7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9b2ddcc7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9b2ddcc7

Branch: refs/heads/master
Commit: 9b2ddcc76b0950124373a30c1dbc56acff664ebf
Parents: bba252a
Author: Takuya Kitazawa <k.tak...@gmail.com>
Authored: Thu Apr 20 16:33:20 2017 +0900
Committer: myui <yuin...@gmail.com>
Committed: Thu Apr 20 16:33:20 2017 +0900

----------------------------------------------------------------------
 .../main/java/hivemall/model/FeatureValue.java  |   4 +
 .../hivemall/topicmodel/LDAPredictUDAF.java     | 476 ++++++++++++++++
 .../main/java/hivemall/topicmodel/LDAUDTF.java  | 567 +++++++++++++++++++
 .../hivemall/topicmodel/OnlineLDAModel.java     | 554 ++++++++++++++++++
 .../java/hivemall/utils/lang/ArrayUtils.java    |  20 +
 .../java/hivemall/utils/math/MathUtils.java     |  43 ++
 .../hivemall/topicmodel/LDAPredictUDAFTest.java | 228 ++++++++
 .../java/hivemall/topicmodel/LDAUDTFTest.java   | 104 ++++
 .../hivemall/topicmodel/OnlineLDAModelTest.java | 252 +++++++++
 docs/gitbook/SUMMARY.md                         |   8 +-
 docs/gitbook/clustering/lda.md                  | 195 +++++++
 resources/ddl/define-all-as-permanent.hive      |  10 +
 resources/ddl/define-all.hive                   |  10 +
 resources/ddl/define-all.spark                  |  10 +
 resources/ddl/define-udfs.td.hql                |   2 +
 15 files changed, 2481 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/model/FeatureValue.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/FeatureValue.java b/core/src/main/java/hivemall/model/FeatureValue.java
index 39fadaf..11aa8f0 100644
--- a/core/src/main/java/hivemall/model/FeatureValue.java
+++ b/core/src/main/java/hivemall/model/FeatureValue.java
@@ -54,6 +54,10 @@ public final class FeatureValue {
         return ((Integer) feature).intValue();
     }
 
+    public String getFeatureAsString() {
+        return feature.toString();
+    }
+
     public double getValue() {
         return value;
     }
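
Note: the new `getFeatureAsString()` accessor is what `OnlineLDAModel` uses below to read back the word part of a parsed "word:count" feature (see `initMiniBatchMap`). A minimal sketch of that usage, assuming the existing `parseFeatureAsString` and `getValueAsFloat` helpers behave as they are called in this patch ("apple:2.0" is just an illustrative feature string):

    // hypothetical standalone snippet, mirroring the calls made in OnlineLDAModel
    FeatureValue probe = new FeatureValue();
    FeatureValue.parseFeatureAsString("apple:2.0", probe);
    String word = probe.getFeatureAsString(); // "apple"
    float count = probe.getValueAsFloat();    // 2.0f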

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
new file mode 100644
index 0000000..811af2e
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.CommandLineUtils;
+import hivemall.utils.lang.Primitives;
+
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+
+@Description(name = "lda_predict",
+        value = "_FUNC_(string word, float value, int label, float lambda[, const string options])"
+                + " - Returns a list which consists of <int label, float prob>")
+public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
+
+    @Override
+    public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
+        if (typeInfo.length != 4 && typeInfo.length != 5) {
+            throw new UDFArgumentLengthException(
+                "Expected argument length is 4 or 5 but given argument length was " + typeInfo.length);
+        }
+
+        if (!HiveUtils.isStringTypeInfo(typeInfo[0])) {
+            throw new UDFArgumentTypeException(0,
+                "String type is expected for the first argument word: " + typeInfo[0].getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[1])) {
+            throw new UDFArgumentTypeException(1,
+                "Number type is expected for the second argument value: "
+                        + typeInfo[1].getTypeName());
+        }
+        if (!HiveUtils.isIntegerTypeInfo(typeInfo[2])) {
+            throw new UDFArgumentTypeException(2,
+                "Integer type is expected for the third argument label: "
+                        + typeInfo[2].getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) {
+            throw new UDFArgumentTypeException(3,
+                "Number type is expected for the fourth argument lambda: "
+                        + typeInfo[3].getTypeName());
+        }
+
+        if (typeInfo.length == 5) {
+            if (!HiveUtils.isStringTypeInfo(typeInfo[4])) {
+                throw new UDFArgumentTypeException(4,
+                    "String type is expected for the fifth argument options: "
+                            + typeInfo[4].getTypeName());
+            }
+        }
+
+        return new Evaluator();
+    }
+
+    public static class Evaluator extends GenericUDAFEvaluator {
+
+        // input OI
+        private PrimitiveObjectInspector wordOI;
+        private PrimitiveObjectInspector valueOI;
+        private PrimitiveObjectInspector labelOI;
+        private PrimitiveObjectInspector lambdaOI;
+
+        // Hyperparameters
+        private int topic;
+        private float alpha;
+        private double delta;
+
+        // merge OI
+        private StructObjectInspector internalMergeOI;
+        private StructField wcListField;
+        private StructField lambdaMapField;
+        private StructField topicOptionField;
+        private StructField alphaOptionField;
+        private StructField deltaOptionField;
+        private PrimitiveObjectInspector wcListElemOI;
+        private StandardListObjectInspector wcListOI;
+        private StandardMapObjectInspector lambdaMapOI;
+        private PrimitiveObjectInspector lambdaMapKeyOI;
+        private StandardListObjectInspector lambdaMapValueOI;
+        private PrimitiveObjectInspector lambdaMapValueElemOI;
+
+        public Evaluator() {}
+
+        protected Options getOptions() {
+            Options opts = new Options();
+            opts.addOption("k", "topic", true, "The number of topics [required]");
+            opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
+            opts.addOption("delta", true,
+                "Check convergence in the expectation step [default: 1E-5]");
+            return opts;
+        }
+
+        @Nonnull
+        protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException {
+            String[] args = optionValue.split("\\s+");
+            Options opts = getOptions();
+            opts.addOption("help", false, "Show function help");
+            CommandLine cl = CommandLineUtils.parseOptions(args, opts);
+
+            if (cl.hasOption("help")) {
+                Description funcDesc = 
getClass().getAnnotation(Description.class);
+                final String cmdLineSyntax;
+                if (funcDesc == null) {
+                    cmdLineSyntax = getClass().getSimpleName();
+                } else {
+                    String funcName = funcDesc.name();
+                    cmdLineSyntax = funcName == null ? 
getClass().getSimpleName()
+                            : funcDesc.value().replace("_FUNC_", 
funcDesc.name());
+                }
+                StringWriter sw = new StringWriter();
+                sw.write('\n');
+                PrintWriter pw = new PrintWriter(sw);
+                HelpFormatter formatter = new HelpFormatter();
+                formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, 
cmdLineSyntax, null, opts,
+                    HelpFormatter.DEFAULT_LEFT_PAD, 
HelpFormatter.DEFAULT_DESC_PAD, null, true);
+                pw.flush();
+                String helpMsg = sw.toString();
+                throw new UDFArgumentException(helpMsg);
+            }
+
+            return cl;
+        }
+
+        protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+            CommandLine cl = null;
+
+            if (argOIs.length != 5) {
+                throw new UDFArgumentException("At least 1 option `-topic` MUST be specified");
+            }
+
+            String rawArgs = HiveUtils.getConstString(argOIs[4]);
+            cl = parseOptions(rawArgs);
+
+            this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 0);
+            if (topic < 1) {
+                throw new UDFArgumentException(
+                    "A positive integer MUST be set to an option `-topic`: " + topic);
+            }
+
+            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic);
+            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-5d);
+
+            return cl;
+        }
+
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+            assert (parameters.length == 4 || parameters.length == 5);
+            super.init(mode, parameters);
+
+            // initialize input
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from 
original data
+                processOptions(parameters);
+                this.wordOI = HiveUtils.asStringOI(parameters[0]);
+                this.valueOI = HiveUtils.asDoubleCompatibleOI(parameters[1]);
+                this.labelOI = HiveUtils.asIntegerOI(parameters[2]);
+                this.lambdaOI = HiveUtils.asDoubleCompatibleOI(parameters[3]);
+            } else {// from partial aggregation
+                StructObjectInspector soi = (StructObjectInspector) 
parameters[0];
+                this.internalMergeOI = soi;
+                this.wcListField = soi.getStructFieldRef("wcList");
+                this.lambdaMapField = soi.getStructFieldRef("lambdaMap");
+                this.topicOptionField = soi.getStructFieldRef("topic");
+                this.alphaOptionField = soi.getStructFieldRef("alpha");
+                this.deltaOptionField = soi.getStructFieldRef("delta");
+                this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+                this.wcListOI = ObjectInspectorFactory.getStandardListObjectInspector(wcListElemOI);
+                this.lambdaMapKeyOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+                this.lambdaMapValueElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+                this.lambdaMapValueOI = ObjectInspectorFactory.getStandardListObjectInspector(lambdaMapValueElemOI);
+                this.lambdaMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+                    lambdaMapKeyOI, lambdaMapValueOI);
+            }
+
+            // initialize output
+            final ObjectInspector outputOI;
+            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// 
terminatePartial
+                outputOI = internalMergeOI();
+            } else {
+                final ArrayList<String> fieldNames = new ArrayList<String>();
+                final ArrayList<ObjectInspector> fieldOIs = new 
ArrayList<ObjectInspector>();
+                fieldNames.add("label");
+                
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+                fieldNames.add("probability");
+                
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+                outputOI = 
ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardStructObjectInspector(
+                    fieldNames, fieldOIs));
+            }
+            return outputOI;
+        }
+
+        private static StructObjectInspector internalMergeOI() {
+            ArrayList<String> fieldNames = new ArrayList<String>();
+            ArrayList<ObjectInspector> fieldOIs = new 
ArrayList<ObjectInspector>();
+
+            fieldNames.add("wcList");
+            
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));
+
+            fieldNames.add("lambdaMap");
+            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
+
+            fieldNames.add("topic");
+            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+            fieldNames.add("alpha");
+            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+            fieldNames.add("delta");
+            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+            return 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+        }
+
+        @SuppressWarnings("deprecation")
+        @Override
+        public AggregationBuffer getNewAggregationBuffer() throws 
HiveException {
+            AggregationBuffer myAggr = new OnlineLDAPredictAggregationBuffer();
+            reset(myAggr);
+            return myAggr;
+        }
+
+        @Override
+        public void reset(@SuppressWarnings("deprecation") AggregationBuffer 
agg)
+                throws HiveException {
+            OnlineLDAPredictAggregationBuffer myAggr = 
(OnlineLDAPredictAggregationBuffer) agg;
+            myAggr.reset();
+            myAggr.setOptions(topic, alpha, delta);
+        }
+
+        @Override
+        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer 
agg,
+                Object[] parameters) throws HiveException {
+            OnlineLDAPredictAggregationBuffer myAggr = 
(OnlineLDAPredictAggregationBuffer) agg;
+
+            if (parameters[0] == null || parameters[1] == null || 
parameters[2] == null
+                    || parameters[3] == null) {
+                return;
+            }
+
+            String word = PrimitiveObjectInspectorUtils.getString(parameters[0], wordOI);
+            float value = HiveUtils.getFloat(parameters[1], valueOI);
+            int label = PrimitiveObjectInspectorUtils.getInt(parameters[2], labelOI);
+            float lambda = HiveUtils.getFloat(parameters[3], lambdaOI);
+
+            myAggr.iterate(word, value, label, lambda);
+        }
+
+        @Override
+        public Object terminatePartial(@SuppressWarnings("deprecation") 
AggregationBuffer agg)
+                throws HiveException {
+            OnlineLDAPredictAggregationBuffer myAggr = 
(OnlineLDAPredictAggregationBuffer) agg;
+            if (myAggr.wcList.size() == 0) {
+                return null;
+            }
+
+            Object[] partialResult = new Object[5];
+            partialResult[0] = myAggr.wcList;
+            partialResult[1] = myAggr.lambdaMap;
+            partialResult[2] = new IntWritable(myAggr.topic);
+            partialResult[3] = new FloatWritable(myAggr.alpha);
+            partialResult[4] = new DoubleWritable(myAggr.delta);
+
+            return partialResult;
+        }
+
+        @Override
+        public void merge(@SuppressWarnings("deprecation") AggregationBuffer 
agg, Object partial)
+                throws HiveException {
+            if (partial == null) {
+                return;
+            }
+
+            Object wcListObj = internalMergeOI.getStructFieldData(partial, 
wcListField);
+
+            List<?> wcListRaw = 
wcListOI.getList(HiveUtils.castLazyBinaryObject(wcListObj));
+
+            // fix list elements to Java String objects
+            int wcListSize = wcListRaw.size();
+            List<String> wcList = new ArrayList<String>();
+            for (int i = 0; i < wcListSize; i++) {
+                
wcList.add(PrimitiveObjectInspectorUtils.getString(wcListRaw.get(i), 
wcListElemOI));
+            }
+
+            Object lambdaMapObj = internalMergeOI.getStructFieldData(partial, 
lambdaMapField);
+            Map<?, ?> lambdaMapRaw = 
lambdaMapOI.getMap(HiveUtils.castLazyBinaryObject(lambdaMapObj));
+
+            Map<String, List<Float>> lambdaMap = new HashMap<String, 
List<Float>>();
+            for (Map.Entry<?, ?> e : lambdaMapRaw.entrySet()) {
+                // fix map keys to Java String objects
+                String word = 
PrimitiveObjectInspectorUtils.getString(e.getKey(), lambdaMapKeyOI);
+
+                Object lambdaMapValueObj = e.getValue();
+                List<?> lambdaMapValueRaw = 
lambdaMapValueOI.getList(HiveUtils.castLazyBinaryObject(lambdaMapValueObj));
+
+                // fix map values to lists of Java Float objects
+                int lambdaMapValueSize = lambdaMapValueRaw.size();
+                List<Float> lambda_word = new ArrayList<Float>();
+                for (int i = 0; i < lambdaMapValueSize; i++) {
+                    
lambda_word.add(HiveUtils.getFloat(lambdaMapValueRaw.get(i),
+                        lambdaMapValueElemOI));
+                }
+
+                lambdaMap.put(word, lambda_word);
+            }
+
+            // restore options from partial result
+            Object topicObj = internalMergeOI.getStructFieldData(partial, 
topicOptionField);
+            this.topic = 
PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj);
+
+            Object alphaObj = internalMergeOI.getStructFieldData(partial, 
alphaOptionField);
+            this.alpha = 
PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj);
+
+            Object deltaObj = internalMergeOI.getStructFieldData(partial, 
deltaOptionField);
+            this.delta = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(deltaObj);
+
+            OnlineLDAPredictAggregationBuffer myAggr = 
(OnlineLDAPredictAggregationBuffer) agg;
+            myAggr.setOptions(topic, alpha, delta);
+            myAggr.merge(wcList, lambdaMap);
+        }
+
+        @Override
+        public Object terminate(@SuppressWarnings("deprecation") 
AggregationBuffer agg)
+                throws HiveException {
+            OnlineLDAPredictAggregationBuffer myAggr = 
(OnlineLDAPredictAggregationBuffer) agg;
+            float[] topicDistr = myAggr.get();
+
+            SortedMap<Float, Integer> sortedDistr = new TreeMap<Float, Integer>(
+                Collections.reverseOrder());
+            for (int i = 0; i < topicDistr.length; i++) {
+                sortedDistr.put(topicDistr[i], i);
+            }
+
+            List<Object[]> result = new ArrayList<Object[]>();
+            for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) {
+                Object[] struct = new Object[2];
+                struct[0] = new IntWritable(e.getValue()); // label
+                struct[1] = new FloatWritable(e.getKey()); // probability
+                result.add(struct);
+            }
+            return result;
+        }
+
+    }
+
+    public static class OnlineLDAPredictAggregationBuffer extends
+            GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+        private List<String> wcList;
+        private Map<String, List<Float>> lambdaMap;
+
+        private int topic;
+        private float alpha;
+        private double delta;
+
+        OnlineLDAPredictAggregationBuffer() {
+            super();
+        }
+
+        void setOptions(int topic, float alpha, double delta) {
+            this.topic = topic;
+            this.alpha = alpha;
+            this.delta = delta;
+        }
+
+        void reset() {
+            this.wcList = new ArrayList<String>();
+            this.lambdaMap = new HashMap<String, List<Float>>();
+        }
+
+        void iterate(String word, float value, int label, float lambda) {
+            wcList.add(word + ":" + value);
+
+            // for an unforeseen word, initialize its lambdas w/ -1s
+            if (!lambdaMap.containsKey(word)) {
+                List<Float> lambdaEmpty_word = new ArrayList<Float>(
+                    Collections.nCopies(topic, -1.f));
+                lambdaMap.put(word, lambdaEmpty_word);
+            }
+
+            // set the given lambda value
+            List<Float> lambda_word = lambdaMap.get(word);
+            lambda_word.set(label, lambda);
+            lambdaMap.put(word, lambda_word);
+        }
+
+        void merge(List<String> o_wcList, Map<String, List<Float>> 
o_lambdaMap) {
+            wcList.addAll(o_wcList);
+
+            for (Map.Entry<String, List<Float>> e : o_lambdaMap.entrySet()) {
+                String o_word = e.getKey();
+                List<Float> o_lambda_word = e.getValue();
+
+                if (!lambdaMap.containsKey(o_word)) { // for an unforeseen word
+                    lambdaMap.put(o_word, o_lambda_word);
+                } else { // for a partially observed word
+                    List<Float> lambda_word = lambdaMap.get(o_word);
+                    for (int k = 0; k < topic; k++) {
+                        if (o_lambda_word.get(k) != -1.f) { // not default value
+                            lambda_word.set(k, o_lambda_word.get(k)); // set the partial lambda value
+                        }
+                    }
+                    lambdaMap.put(o_word, lambda_word);
+                }
+            }
+        }
+
+        float[] get() {
+            OnlineLDAModel model = new OnlineLDAModel(topic, alpha, delta);
+
+            for (String word : lambdaMap.keySet()) {
+                List<Float> lambda_word = lambdaMap.get(word);
+                for (int k = 0; k < topic; k++) {
+                    model.setLambda(word, k, lambda_word.get(k));
+                }
+            }
+
+            String[] wcArray = wcList.toArray(new String[wcList.size()]);
+            return model.getTopicDistribution(wcArray);
+        }
+    }
+
+}
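
Note: the aggregation buffer above initializes a word's per-topic lambdas with -1 sentinels and only overwrites the entries a mapper actually observed, so merging two partial results keeps the union of observed values. A minimal standalone sketch of that merge rule (hypothetical values, not part of the patch; uses java.util.ArrayList, java.util.Arrays, java.util.List):

    // lambdas for one word from two partial aggregations; -1 marks "not observed"
    List<Float> merged = new ArrayList<Float>(Arrays.asList(-1.f, 0.3f, -1.f));
    List<Float> other = new ArrayList<Float>(Arrays.asList(0.1f, -1.f, 0.6f));
    for (int k = 0; k < merged.size(); k++) {
        if (other.get(k) != -1.f) {          // not the sentinel
            merged.set(k, other.get(k));     // take the observed partial lambda
        }
    }
    // merged is now [0.1, 0.3, 0.6]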

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
new file mode 100644
index 0000000..91ee7a2
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -0,0 +1,567 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+@Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])"
+        + " - Returns a relation consisting of <int topic, string word, float score>")
+public class LDAUDTF extends UDTFWithOptions {
+    private static final Log logger = LogFactory.getLog(LDAUDTF.class);
+
+    // Options
+    protected int topic;
+    protected float alpha;
+    protected float eta;
+    protected long numDocs;
+    protected double tau0;
+    protected double kappa;
+    protected int iterations;
+    protected double delta;
+    protected double eps;
+    protected int miniBatchSize;
+
+    // if the `num_docs` option is not given, this flag will be true;
+    // in that case, the UDTF automatically sets the `count` value to the _D parameter of the online LDA model
+    protected boolean isAutoD;
+
+    // number of processed training samples
+    protected long count;
+
+    protected String[][] miniBatch;
+    protected int miniBatchCount;
+
+    protected transient OnlineLDAModel model;
+
+    protected ListObjectInspector wordCountsOI;
+
+    // for iterations
+    protected NioStatefullSegment fileIO;
+    protected ByteBuffer inputBuf;
+
+    public LDAUDTF() {
+        this.topic = 10;
+        this.alpha = 1.f / topic;
+        this.eta = 1.f / topic;
+        this.numDocs = -1L;
+        this.tau0 = 64.d;
+        this.kappa = 0.7;
+        this.iterations = 10;
+        this.delta = 1E-3d;
+        this.eps = 1E-1d;
+        this.miniBatchSize = 128; // if 1, truly online setting
+    }
+
+    @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        opts.addOption("k", "topic", true, "The number of topics [default: 10]");
+        opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
+        opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]");
+        opts.addOption("d", "num_docs", true, "The total number of documents [default: auto]");
+        opts.addOption("tau", "tau0", true,
+            "The parameter which downweights early iterations [default: 64.0]");
+        opts.addOption("kappa", true, "Exponential decay rate (i.e., learning rate) [default: 0.7]");
+        opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
+        opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]");
+        opts.addOption("eps", "epsilon", true,
+            "Check convergence based on the difference of perplexity [default: 1E-1]");
+        opts.addOption("s", "mini_batch_size", true,
+            "Repeat model updating per mini-batch [default: 128]");
+        return opts;
+    }
+
+    @Override
+    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+        CommandLine cl = null;
+
+        if (argOIs.length >= 2) {
+            String rawArgs = HiveUtils.getConstString(argOIs[1]);
+            cl = parseOptions(rawArgs);
+            this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 10);
+            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic);
+            this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topic);
+            this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L);
+            this.tau0 = Primitives.parseDouble(cl.getOptionValue("tau0"), 64.d);
+            if (tau0 <= 0.d) {
+                throw new UDFArgumentException("'-tau0' must be positive: " + tau0);
+            }
+            this.kappa = Primitives.parseDouble(cl.getOptionValue("kappa"), 0.7d);
+            if (kappa <= 0.5 || kappa > 1.d) {
+                throw new UDFArgumentException("'-kappa' must be in (0.5, 1.0]: " + kappa);
+            }
+            this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
+            if (iterations < 1) {
+                throw new UDFArgumentException(
+                    "'-iterations' must be greater than or equal to 1: " + iterations);
+            }
+            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-3d);
+            this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
+            this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
+        }
+
+        return cl;
+    }
+
+    @Override
+    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length < 1) {
+            throw new UDFArgumentException(
+                "_FUNC_ takes 1 or 2 arguments: array<string> words [, const string options]");
+        }
+
+        this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
+        
HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());
+
+        processOptions(argOIs);
+
+        this.model = null;
+        this.count = 0L;
+        this.isAutoD = (numDocs < 0L);
+        this.miniBatch = new String[miniBatchSize][];
+        this.miniBatchCount = 0;
+
+        ArrayList<String> fieldNames = new ArrayList<String>();
+        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+        fieldNames.add("topic");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+        fieldNames.add("word");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+        fieldNames.add("score");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+        return 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+    }
+
+    protected void initModel() {
+        this.model = new OnlineLDAModel(topic, alpha, eta, numDocs, tau0, kappa, delta);
+    }
+
+    @Override
+    public void process(Object[] args) throws HiveException {
+        if (model == null) {
+            initModel();
+        }
+
+        int length = wordCountsOI.getListLength(args[0]);
+        String[] wordCounts = new String[length];
+        int j = 0;
+        for (int i = 0; i < length; i++) {
+            Object o = wordCountsOI.getListElement(args[0], i);
+            if (o == null) {
+                throw new HiveException("Given feature vector contains invalid 
elements");
+            }
+            String s = o.toString();
+            wordCounts[j] = s;
+            j++;
+        }
+
+        count++;
+        if (isAutoD) {
+            model.setNumTotalDocs(count);
+        }
+
+        recordTrainSampleToTempFile(wordCounts);
+
+        miniBatch[miniBatchCount] = wordCounts;
+        miniBatchCount++;
+
+        if (miniBatchCount == miniBatchSize) {
+            model.train(miniBatch);
+            Arrays.fill(miniBatch, null); // clear
+            miniBatchCount = 0;
+        }
+    }
+
+    protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts)
+            throws HiveException {
+        if (iterations == 1) {
+            return;
+        }
+
+        ByteBuffer buf = inputBuf;
+        NioStatefullSegment dst = fileIO;
+
+        if (buf == null) {
+            final File file;
+            try {
+                file = File.createTempFile("hivemall_lda", ".sgmt");
+                file.deleteOnExit();
+                if (!file.canWrite()) {
+                    throw new UDFArgumentException("Cannot write a temporary 
file: "
+                            + file.getAbsolutePath());
+                }
+                logger.info("Recording training samples to a file: " + file.getAbsolutePath());
+            } catch (IOException ioe) {
+                throw new UDFArgumentException(ioe);
+            } catch (Throwable e) {
+                throw new UDFArgumentException(e);
+            }
+            this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
+            this.fileIO = dst = new NioStatefullSegment(file, false);
+        }
+
+        int wcLength = 0;
+        for (String wc : wordCounts) {
+            if (wc == null) {
+                continue;
+            }
+            wcLength += wc.getBytes().length;
+        }
+        // recordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ...
+        int recordBytes = (Integer.SIZE * 2 + Integer.SIZE * wcLength) / 8 + wcLength;
+        int remain = buf.remaining();
+        if (remain < recordBytes) {
+            writeBuffer(buf, dst);
+        }
+
+        buf.putInt(recordBytes);
+        buf.putInt(wordCounts.length);
+        for (String wc : wordCounts) {
+            if (wc == null) {
+                continue;
+            }
+            buf.putInt(wc.length());
+            buf.put(wc.getBytes());
+        }
+    }
+
+    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
+            throws HiveException {
+        srcBuf.flip();
+        try {
+            dst.write(srcBuf);
+        } catch (IOException e) {
+            throw new HiveException("Exception caused while writing a buffer to a file", e);
+        }
+        srcBuf.clear();
+    }
+
+    @Override
+    public void close() throws HiveException {
+        if (count == 0) {
+            this.model = null;
+            return;
+        }
+        if (miniBatchCount > 0) { // update for remaining samples
+            model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+        }
+        if (iterations > 1) {
+            runIterativeTraining(iterations);
+        }
+        forwardModel();
+        this.model = null;
+    }
+
+    protected final void runIterativeTraining(@Nonnegative final int iterations)
+            throws HiveException {
+        final ByteBuffer buf = this.inputBuf;
+        final NioStatefullSegment dst = this.fileIO;
+        assert (buf != null);
+        assert (dst != null);
+        final long numTrainingExamples = count;
+
+        final Reporter reporter = getReporter();
+        final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
+            "hivemall.lda.OnlineLDA$Counter", "iteration");
+
+        try {
+            if (dst.getPosition() == 0L) {// run iterations w/o temporary file
+                if (buf.position() == 0) {
+                    return; // no training example
+                }
+                buf.flip();
+
+                int iter = 2;
+                float perplexityPrev = Float.MAX_VALUE;
+                float perplexity;
+                int numTrain;
+                for (; iter <= iterations; iter++) {
+                    perplexity = 0.f;
+                    numTrain = 0;
+
+                    reportProgress(reporter);
+                    setCounterValue(iterCounter, iter);
+
+                    Arrays.fill(miniBatch, null); // clear
+                    miniBatchCount = 0;
+
+                    while (buf.remaining() > 0) {
+                        int recordBytes = buf.getInt();
+                        assert (recordBytes > 0) : recordBytes;
+                        int wcLength = buf.getInt();
+                        final String[] wordCounts = new String[wcLength];
+                        for (int j = 0; j < wcLength; j++) {
+                            int len = buf.getInt();
+                            byte[] bytes = new byte[len];
+                            buf.get(bytes);
+                            wordCounts[j] = new String(bytes);
+                        }
+
+                        miniBatch[miniBatchCount] = wordCounts;
+                        miniBatchCount++;
+
+                        if (miniBatchCount == miniBatchSize) {
+                            model.train(miniBatch);
+                            perplexity += model.computePerplexity();
+                            numTrain++;
+
+                            Arrays.fill(miniBatch, null); // clear
+                            miniBatchCount = 0;
+                        }
+                    }
+                    buf.rewind();
+
+                    // update for remaining samples
+                    if (miniBatchCount > 0) { // update for remaining samples
+                        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+                        perplexity += model.computePerplexity();
+                        numTrain++;
+                    }
+
+                    logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
+                    perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
+                    if (Math.abs(perplexityPrev - perplexity) < eps) {
+                        break;
+                    }
+                    perplexityPrev = perplexity;
+                }
+                logger.info("Performed " + Math.min(iter, iterations) + " iterations of "
+                        + NumberUtils.formatNumber(numTrainingExamples)
+                        + " training examples in memory (thus "
+                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+                        + " training updates in total)");
+            } else {// read training examples in the temporary file and invoke train for each example
+
+                // write training examples in buffer to a temporary file
+                if (buf.remaining() > 0) {
+                    writeBuffer(buf, dst);
+                }
+                try {
+                    dst.flush();
+                } catch (IOException e) {
+                    throw new HiveException("Failed to flush a file: "
+                            + dst.getFile().getAbsolutePath(), e);
+                }
+                if (logger.isInfoEnabled()) {
+                    File tmpFile = dst.getFile();
+                    logger.info("Wrote " + numTrainingExamples
+                            + " records to a temporary file for iterative training: "
+                            + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
+                            + ")");
+                }
+
+                // run iterations
+                int iter = 2;
+                float perplexityPrev = Float.MAX_VALUE;
+                float perplexity;
+                int numTrain;
+                for (; iter <= iterations; iter++) {
+                    perplexity = 0.f;
+                    numTrain = 0;
+
+                    Arrays.fill(miniBatch, null); // clear
+                    miniBatchCount = 0;
+
+                    setCounterValue(iterCounter, iter);
+
+                    buf.clear();
+                    dst.resetPosition();
+                    while (true) {
+                        reportProgress(reporter);
+                        // TODO prefetch
+                        // read training examples from the temporary file into the buffer
+                        final int bytesRead;
+                        try {
+                            bytesRead = dst.read(buf);
+                        } catch (IOException e) {
+                            throw new HiveException("Failed to read a file: "
+                                    + dst.getFile().getAbsolutePath(), e);
+                        }
+                        if (bytesRead == 0) { // reached file EOF
+                            break;
+                        }
+                        assert (bytesRead > 0) : bytesRead;
+
+                        // reads training examples from a buffer
+                        buf.flip();
+                        int remain = buf.remaining();
+                        if (remain < SizeOf.INT) {
+                            throw new HiveException("Illegal file format was detected");
+                        }
+                        while (remain >= SizeOf.INT) {
+                            int pos = buf.position();
+                            int recordBytes = buf.getInt();
+                            remain -= SizeOf.INT;
+                            if (remain < recordBytes) {
+                                buf.position(pos);
+                                break;
+                            }
+
+                            int wcLength = buf.getInt();
+                            final String[] wordCounts = new String[wcLength];
+                            for (int j = 0; j < wcLength; j++) {
+                                int len = buf.getInt();
+                                byte[] bytes = new byte[len];
+                                buf.get(bytes);
+                                wordCounts[j] = new String(bytes);
+                            }
+
+                            miniBatch[miniBatchCount] = wordCounts;
+                            miniBatchCount++;
+
+                            if (miniBatchCount == miniBatchSize) {
+                                model.train(miniBatch);
+                                perplexity += model.computePerplexity();
+                                numTrain++;
+
+                                Arrays.fill(miniBatch, null); // clear
+                                miniBatchCount = 0;
+                            }
+
+                            remain -= recordBytes;
+                        }
+                        buf.compact();
+                    }
+
+                    // update for remaining samples
+                    if (miniBatchCount > 0) { // update for remaining samples
+                        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+                        perplexity += model.computePerplexity();
+                        numTrain++;
+                    }
+
+                    logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
+                    perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
+                    if (Math.abs(perplexityPrev - perplexity) < eps) {
+                        break;
+                    }
+                    perplexityPrev = perplexity;
+                }
+                logger.info("Performed " + Math.min(iter, iterations) + " iterations of "
+                        + NumberUtils.formatNumber(numTrainingExamples)
+                        + " training examples on secondary storage (thus "
+                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+                        + " training updates in total)");
+            }
+        } finally {
+            // delete the temporary file and release resources
+            try {
+                dst.close(true);
+            } catch (IOException e) {
+                throw new HiveException("Failed to close a file: "
+                        + dst.getFile().getAbsolutePath(), e);
+            }
+            this.inputBuf = null;
+            this.fileIO = null;
+        }
+    }
+
+    protected void forwardModel() throws HiveException {
+        final IntWritable topicIdx = new IntWritable();
+        final Text word = new Text();
+        final FloatWritable score = new FloatWritable();
+
+        final Object[] forwardObjs = new Object[3];
+        forwardObjs[0] = topicIdx;
+        forwardObjs[1] = word;
+        forwardObjs[2] = score;
+
+        for (int k = 0; k < topic; k++) {
+            topicIdx.set(k);
+
+            final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
+            for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+                score.set(e.getKey());
+                List<String> words = e.getValue();
+                for (int i = 0; i < words.size(); i++) {
+                    word.set(words.get(i));
+                    forward(forwardObjs);
+                }
+            }
+        }
+
+        logger.info("Forwarded topic words each of " + topic + " topics");
+    }
+
+    /*
+     * For testing:
+     */
+
+    @VisibleForTesting
+    double getLambda(String label, int k) {
+        return model.getLambda(label, k);
+    }
+
+    @VisibleForTesting
+    SortedMap<Float, List<String>> getTopicWords(int k) {
+        return model.getTopicWords(k);
+    }
+
+    @VisibleForTesting
+    SortedMap<Float, List<String>> getTopicWords(int k, int topN) {
+        return model.getTopicWords(k, topN);
+    }
+
+    @VisibleForTesting
+    float[] getTopicDistribution(@Nonnull String[] doc) {
+        return model.getTopicDistribution(doc);
+    }
+}
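
Note: when `-iterations` is greater than 1, each training document is spilled to the temporary segment as a length-prefixed record (record-size header, number of word-count strings, then each string as a byte length plus its bytes) and later read back mini-batch by mini-batch. A minimal sketch of that layout with a plain ByteBuffer, assuming it matches the putInt/put calls above (the patch computes recordBytes as a conservative size estimate; this illustration just uses the exact payload size and is not the NioStatefullSegment API; uses java.nio.ByteBuffer):

    // layout: [recordBytes][numWordCounts]([len][bytes])*
    String[] wordCounts = {"apple:2", "orange:1"};   // hypothetical document
    ByteBuffer buf = ByteBuffer.allocate(1024);
    int payload = 0;
    for (String wc : wordCounts) {
        payload += 4 + wc.getBytes().length;         // length prefix + string bytes
    }
    buf.putInt(4 + payload);                         // record size
    buf.putInt(wordCounts.length);                   // number of word-count strings
    for (String wc : wordCounts) {
        byte[] b = wc.getBytes();
        buf.putInt(b.length);
        buf.put(b);
    }

    buf.flip();                                      // read the record back
    int recordBytes = buf.getInt();                  // record size header
    String[] restored = new String[buf.getInt()];
    for (int j = 0; j < restored.length; j++) {
        byte[] b = new byte[buf.getInt()];
        buf.get(b);
        restored[j] = new String(b);
    }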

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
new file mode 100644
index 0000000..3e7ad10
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -0,0 +1,554 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.annotations.VisibleForTesting;
+import hivemall.model.FeatureValue;
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.math.MathUtils;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.math3.distribution.GammaDistribution;
+import org.apache.commons.math3.special.Gamma;
+
+public final class OnlineLDAModel {
+
+    // number of topics
+    private final int _K;
+
+    // prior on weight vectors "theta ~ Dir(alpha_)"
+    private final float _alpha;
+
+    // prior on topics "beta"
+    private final float _eta;
+
+    // total number of documents
+    // in the truly online setting, this can be an estimate of the maximum number of documents that could ever be seen
+    private long _D = -1L;
+
+    // defined by (_tau0 + _updateCount)^(-_kappa);
+    // controls how much old lambda is forgotten
+    private double _rhot;
+
+    // positive value which downweights early iterations
+    @Nonnegative
+    private final double _tau0;
+
+    // exponential decay rate (i.e., learning rate) which must be in (0.5, 1] to guarantee convergence
+    private final double _kappa;
+
+    // how many times EM steps are launched; later EM steps do not drastically forget old lambda
+    private long _updateCount = 0L;
+
+    // random number generator
+    @Nonnull
+    private final GammaDistribution _gd;
+    private static final double SHAPE = 100.d;
+    private static final double SCALE = 1.d / SHAPE;
+
+    // parameters
+    @Nonnull
+    private List<Map<String, float[]>> _phi;
+    private float[][] _gamma;
+    @Nonnull
+    private final Map<String, float[]> _lambda;
+
+    // check convergence in the expectation (E) step
+    private final double _delta;
+
+    @Nonnull
+    private final List<Map<String, Float>> _miniBatchMap;
+    private int _miniBatchSize;
+
+    // for computing perplexity
+    private float _docRatio = 1.f;
+    private long _wordCount = 0L;
+
+    public OnlineLDAModel(int K, float alpha, double delta) { // for E-step-only instantiation
+        this(K, alpha, 1 / 20.f, -1L, 1020, 0.7, delta);
+    }
+
+    public OnlineLDAModel(int K, float alpha, float eta, long D, double tau0, double kappa,
+            double delta) {
+        if (tau0 <= 0.d) {
+            throw new IllegalArgumentException("tau0 MUST be positive: " + tau0);
+        }
+        if (kappa <= 0.5 || 1.d < kappa) {
+            throw new IllegalArgumentException("kappa MUST be in (0.5, 1.0]: " + kappa);
+        }
+
+        this._K = K;
+        this._alpha = alpha;
+        this._eta = eta;
+        this._D = D;
+        this._tau0 = tau0;
+        this._kappa = kappa;
+        this._delta = delta;
+
+        // initialize a random number generator
+        this._gd = new GammaDistribution(SHAPE, SCALE);
+        _gd.reseedRandomGenerator(1001);
+
+        // initialize the parameters
+        this._lambda = new HashMap<String, float[]>(100);
+
+        this._miniBatchMap = new ArrayList<Map<String, Float>>();
+    }
+
+    /**
+     * In a truly online setting, the total number of documents corresponds to the number of
+     * documents that have ever been seen. In that case, users need to manually set the current
+     * max number of documents via this method. Note that, since the same set of documents could
+     * be repeatedly passed to `train()`, simply accumulating `_miniBatchSize`s as the estimated
+     * `_D` is not sufficient.
+     */
+    public void setNumTotalDocs(@Nonnegative long D) {
+        this._D = D;
+    }
+
+    public void train(@Nonnull final String[][] miniBatch) {
+        if (_D <= 0L) {
+            throw new RuntimeException("Total number of documents MUST be set via `setNumTotalDocs()`");
+        }
+
+        preprocessMiniBatch(miniBatch);
+
+        initParams(true);
+
+        // Expectation
+        eStep();
+
+        this._rhot = Math.pow(_tau0 + _updateCount, -_kappa);
+
+        // Maximization
+        mStep();
+
+        _updateCount++;
+    }
+
+    private void preprocessMiniBatch(@Nonnull final String[][] miniBatch) {
+        initMiniBatchMap(miniBatch, _miniBatchMap);
+
+        this._miniBatchSize = _miniBatchMap.size();
+
+        // accumulate the number of words for each document
+        this._wordCount = 0L;
+        for (int d = 0; d < _miniBatchSize; d++) {
+            for (float n : _miniBatchMap.get(d).values()) {
+                this._wordCount += n;
+            }
+        }
+
+        this._docRatio = (float)((double) _D / _miniBatchSize);
+    }
+
+    private static void initMiniBatchMap(@Nonnull final String[][] miniBatch,
+            @Nonnull final List<Map<String, Float>> map) {
+        map.clear();
+
+        final FeatureValue probe = new FeatureValue();
+
+        // parse document
+        for (final String[] e : miniBatch) {
+            if (e == null) {
+                continue;
+            }
+
+            final Map<String, Float> docMap = new HashMap<String, Float>();
+
+            // parse features
+            for (String fv : e) {
+                if (fv == null) {
+                    continue;
+                }
+                FeatureValue.parseFeatureAsString(fv, probe);
+                String label = probe.getFeatureAsString();
+                float value = probe.getValueAsFloat();
+                docMap.put(label, value);
+            }
+
+            map.add(docMap);
+        }
+    }
+
+    private void initParams(boolean gammaWithRandom) {
+        _phi = new ArrayList<Map<String, float[]>>();
+        _gamma = new float[_miniBatchSize][];
+
+        for (int d = 0; d < _miniBatchSize; d++) {
+            if (gammaWithRandom) {
+                _gamma[d] = ArrayUtils.newRandomFloatArray(_K, _gd);
+            } else {
+                _gamma[d] = ArrayUtils.newInstance(_K, 1.f);
+            }
+
+            final Map<String, float[]> phi_d = new HashMap<String, float[]>();
+            _phi.add(phi_d);
+            for (String label : _miniBatchMap.get(d).keySet()) {
+                phi_d.put(label, new float[_K]);
+                if (!_lambda.containsKey(label)) { // lambda for newly observed word
+                    _lambda.put(label, ArrayUtils.newRandomFloatArray(_K, _gd));
+                }
+            }
+        }
+    }
+
+    private void eStep() {
+        // since lambda is invariant in the expectation step,
+        // `digamma`s of lambda values for Elogbeta are pre-computed
+        final float[] lambdaSum = new float[_K];
+        final Map<String, float[]> digamma_lambda = new HashMap<String, float[]>();
+        for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
+            String label = e.getKey();
+            float[] lambda_label = e.getValue();
+
+            // for digamma(lambdaSum)
+            MathUtils.add(lambdaSum, lambda_label, _K);
+
+            float[] digamma_lambda_label = MathUtils.digamma(lambda_label);
+            digamma_lambda.put(label, digamma_lambda_label);
+        }
+        final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+
+        float[] gamma_d, gammaPrev_d;
+        Map<String, float[]> eLogBeta_d;
+
+        // for each of mini-batch documents, update gamma until convergence
+        for (int d = 0; d < _miniBatchSize; d++) {
+            gamma_d = _gamma[d];
+            eLogBeta_d = computeElogBetaPerDoc(d, digamma_lambda, digamma_lambdaSum);
+
+            do {
+                // (deep) copy the last gamma values
+                gammaPrev_d = gamma_d.clone();
+
+                updatePhiPerDoc(d, eLogBeta_d);
+                updateGammaPerDoc(d);
+            } while (!checkGammaDiff(gammaPrev_d, gamma_d));
+        }
+    }
+
+    @Nonnull
+    private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative final int d,
+            @Nonnull Map<String, float[]> digamma_lambda, @Nonnull float[] digamma_lambdaSum) {
+        // Dirichlet expectation (2d) for lambda
+        final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>();
+        final Map<String, Float> doc = _miniBatchMap.get(d);
+
+        for (String label : doc.keySet()) {
+            float[] eLogBeta_label = eLogBeta_d.get(label);
+            if (eLogBeta_label == null) {
+                eLogBeta_label = new float[_K];
+                eLogBeta_d.put(label, eLogBeta_label);
+            }
+            final float[] digamma_lambda_label = digamma_lambda.get(label);
+            for (int k = 0; k < _K; k++) {
+                eLogBeta_label[k] = digamma_lambda_label[k] - digamma_lambdaSum[k];
+            }
+        }
+
+        return eLogBeta_d;
+    }
+
+    private void updatePhiPerDoc(@Nonnegative final int d, @Nonnull Map<String, float[]> eLogBeta_d) {
+        // Dirichlet expectation (2d) for gamma
+        final float[] eLogTheta_d = new float[_K];
+        final float[] gamma_d = _gamma[d];
+        final float digamma_gammaSum_d = (float) Gamma.digamma(MathUtils.sum(gamma_d));
+        for (int k = 0; k < _K; k++) {
+            eLogTheta_d[k] = (float) Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
+        }
+
+        // updating phi w/ normalization
+        final Map<String, float[]> phi_d = _phi.get(d);
+        final Map<String, Float> doc = _miniBatchMap.get(d);
+        for (String label : doc.keySet()) {
+            final float[] phi_label = phi_d.get(label);
+            final float[] eLogBeta_label = eLogBeta_d.get(label);
+
+            float normalizer = 0.f;
+            for (int k = 0; k < _K; k++) {
+                float phiVal = (float) Math.exp(eLogBeta_label[k] + eLogTheta_d[k]) + 1E-20f;
+                phi_label[k] = phiVal;
+                normalizer += phiVal;
+            }
+
+            // normalize
+            for (int k = 0; k < _K; k++) {
+                phi_label[k] /= normalizer;
+            }
+        }
+    }
+
+    private void updateGammaPerDoc(@Nonnegative final int d) {
+        final Map<String, Float> doc = _miniBatchMap.get(d);
+        final Map<String, float[]> phi_d = _phi.get(d);
+
+        final float[] gamma_d = _gamma[d];
+        for (int k = 0; k < _K; k++) {
+            gamma_d[k] = _alpha;
+        }
+        for (Map.Entry<String, Float> e : doc.entrySet()) {
+            final float[] phi_label = phi_d.get(e.getKey());
+            final float val = e.getValue();
+            for (int k = 0; k < _K; k++) {
+                gamma_d[k] += phi_label[k] * val;
+            }
+        }
+    }
+
+    private boolean checkGammaDiff(@Nonnull final float[] gammaPrev,
+            @Nonnull final float[] gammaNext) {
+        double diff = 0.d;
+        for (int k = 0; k < _K; k++) {
+            diff += Math.abs(gammaPrev[k] - gammaNext[k]);
+        }
+        return (diff / _K) < _delta;
+    }
+
+    private void mStep() {
+        // calculate lambdaTilde for vocabularies in the current mini-batch
+        final Map<String, float[]> lambdaTilde = new HashMap<String, float[]>();
+        for (int d = 0; d < _miniBatchSize; d++) {
+            final Map<String, float[]> phi_d = _phi.get(d);
+            for (String label : _miniBatchMap.get(d).keySet()) {
+                float[] lambdaTilde_label = lambdaTilde.get(label);
+                if (lambdaTilde_label == null) {
+                    lambdaTilde_label = ArrayUtils.newInstance(_K, _eta);
+                    lambdaTilde.put(label, lambdaTilde_label);
+                }
+
+                final float[] phi_label = phi_d.get(label);
+                for (int k = 0; k < _K; k++) {
+                    lambdaTilde_label[k] += _docRatio * phi_label[k];
+                }
+            }
+        }
+
+        // update lambda for all vocabularies
+        for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
+            String label = e.getKey();
+            final float[] lambda_label = e.getValue();
+
+            float[] lambdaTilde_label = lambdaTilde.get(label);
+            if (lambdaTilde_label == null) {
+                lambdaTilde_label = ArrayUtils.newInstance(_K, _eta);
+            }
+
+            for (int k = 0; k < _K; k++) {
+                lambda_label[k] = (float) ((1.d - _rhot) * lambda_label[k] + _rhot
+                        * lambdaTilde_label[k]);
+            }
+        }
+    }
+
+    /**
+     * Calculate approximate perplexity for the current mini-batch.
+     */
+    public float computePerplexity() {
+        float bound = computeApproxBound();
+        float perWordBound = bound / (_docRatio * _wordCount);
+        return (float) Math.exp(-1.f * perWordBound);
+    }
+
+    /**
+     * Estimates the variational bound over all documents, using only the documents in the current mini-batch.
+     */
+    private float computeApproxBound() {
+        float score = 0.f;
+
+        // prepare
+        final float[] gammaSum = new float[_miniBatchSize];
+        for (int d = 0; d < _miniBatchSize; d++) {
+            gammaSum[d] = MathUtils.sum(_gamma[d]);
+        }
+        final float[] digamma_gammaSum = MathUtils.digamma(gammaSum);
+
+        final float[] lambdaSum = new float[_K];
+        for (float[] lambda_label : _lambda.values()) {
+            MathUtils.add(lambdaSum, lambda_label, _K);
+        }
+        final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+
+        final float logGamma_alpha = (float) Gamma.logGamma(_alpha);
+        final float logGamma_alphaSum = (float) Gamma.logGamma(_K * _alpha);
+
+        for (int d = 0; d < _miniBatchSize; d++) {
+            final float digamma_gammaSum_d = digamma_gammaSum[d];
+
+            // E[log p(doc | theta, beta)]
+            for (Map.Entry<String, Float> e : _miniBatchMap.get(d).entrySet()) {
+                final float[] lambda_label = _lambda.get(e.getKey());
+
+                // logsumexp( Elogthetad + Elogbetad )
+                final float[] temp = new float[_K];
+                float max = Float.NEGATIVE_INFINITY; // temp[k] is negative; Float.MIN_VALUE (smallest positive value) would never be updated
+                for (int k = 0; k < _K; k++) {
+                    final float eLogTheta_dk = (float) Gamma.digamma(_gamma[d][k]) - digamma_gammaSum_d;
+                    final float eLogBeta_kw = (float) Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k];
+
+                    temp[k] = eLogTheta_dk + eLogBeta_kw;
+                    if (temp[k] > max) {
+                        max = temp[k];
+                    }
+                }
+                float logsumexp = 0.f;
+                for (int k = 0; k < _K; k++) {
+                    logsumexp += (float) Math.exp(temp[k] - max);
+                }
+                logsumexp = max + (float) Math.log(logsumexp);
+
+                // sum( word count * logsumexp(...) )
+                score += e.getValue() * logsumexp;
+            }
+
+            // E[log p(theta | alpha) - log q(theta | gamma)]
+            for (int k = 0; k < _K; k++) {
+                final float gamma_dk = _gamma[d][k];
+
+                // sum( (alpha - gammad) * Elogthetad )
+                score += (_alpha - gamma_dk)
+                        * ((float) Gamma.digamma(gamma_dk) - digamma_gammaSum_d);
+
+                // sum( gammaln(gammad) - gammaln(alpha) )
+                score += (float) Gamma.logGamma(gamma_dk) - logGamma_alpha;
+            }
+            score += logGamma_alphaSum; // gammaln(sum(alpha))
+            score -= Gamma.logGamma(gammaSum[d]); // gammaln(sum(gammad))
+        }
+
+        // The mini-batch contains only a subset of the whole corpus (online setting),
+        // so rescale the likelihood to keep it roughly on the same scale as the full corpus.
+        score *= _docRatio;
+
+        final float logGamma_eta = (float) Gamma.logGamma(_eta);
+        final float logGamma_etaSum = (float) Gamma.logGamma(_eta * _lambda.size()); // vocabulary size * eta
+
+        // E[log p(beta | eta) - log q (beta | lambda)]
+        for (float[] lambda_label : _lambda.values()) {
+            for (int k = 0; k < _K; k++) {
+                final float lambda_k = lambda_label[k];
+
+                // sum( (eta - lambda) * Elogbeta )
+                score += (_eta - lambda_k)
+                        * (float) (Gamma.digamma(lambda_k) - digamma_lambdaSum[k]);
+
+                // sum( gammaln(lambda) - gammaln(eta) )
+                score += (float) Gamma.logGamma(lambda_k) - logGamma_eta;
+            }
+        }
+        for (int k = 0; k < _K; k++) {
+            // sum( gammaln(etaSum) - gammaln(lambdaSum_k) )
+            score += logGamma_etaSum - (float) Gamma.logGamma(lambdaSum[k]);
+        }
+
+        return score;
+    }
+
+    @VisibleForTesting
+    double getLambda(@Nonnull final String label, @Nonnegative final int k) {
+        final float[] lambda_label = _lambda.get(label);
+        if (lambda_label == null) {
+            throw new IllegalArgumentException("Word `" + label + "` is not in the corpus.");
+        }
+        if (k >= lambda_label.length) {
+            throw new IllegalArgumentException("Topic index must be in [0, "
+                    + lambda_label.length + ")");
+        }
+        return lambda_label[k];
+    }
+
+    public void setLambda(@Nonnull final String label, @Nonnegative final int k, final float lambda_k) {
+        float[] lambda_label = _lambda.get(label);
+        if (lambda_label == null) {
+            lambda_label = ArrayUtils.newRandomFloatArray(_K, _gd);
+            _lambda.put(label, lambda_label);
+        }
+        lambda_label[k] = lambda_k;
+    }
+
+    @Nonnull
+    public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
+        return getTopicWords(k, _lambda.keySet().size());
+    }
+
+    @Nonnull
+    public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k,
+            @Nonnegative int topN) {
+        float lambdaSum = 0.f;
+        final SortedMap<Float, List<String>> sortedLambda = new TreeMap<Float, List<String>>(
+            Collections.reverseOrder());
+
+        for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
+            final float lambda_k = e.getValue()[k];
+            lambdaSum += lambda_k;
+
+            List<String> labels = sortedLambda.get(lambda_k);
+            if (labels == null) {
+                labels = new ArrayList<String>();
+                sortedLambda.put(lambda_k, labels);
+            }
+            labels.add(e.getKey());
+        }
+
+        final SortedMap<Float, List<String>> ret = new TreeMap<Float, List<String>>(
+            Collections.reverseOrder());
+
+        topN = Math.min(topN, _lambda.keySet().size());
+        int tt = 0;
+        for (Map.Entry<Float, List<String>> e : sortedLambda.entrySet()) {
+            ret.put(e.getKey() / lambdaSum, e.getValue());
+
+            if (++tt == topN) {
+                break;
+            }
+        }
+
+        return ret;
+    }
+
+    @Nonnull
+    public float[] getTopicDistribution(@Nonnull final String[] doc) {
+        preprocessMiniBatch(new String[][] {doc});
+
+        initParams(false);
+
+        eStep();
+
+        // normalize topic distribution
+        final float[] topicDistr = new float[_K];
+        final float[] gamma0 = _gamma[0];
+        final float gammaSum = MathUtils.sum(gamma0);
+        for (int k = 0; k < _K; k++) {
+            topicDistr[k] = gamma0[k] / gammaSum;
+        }
+        return topicDistr;
+    }
+
+}
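
For reference, the M-step above blends the mini-batch estimate lambdaTilde into lambda with weight rho_t = (tau0 + updateCount)^(-kappa), so early mini-batches move lambda a lot and later ones progressively less. A minimal standalone sketch of that schedule and blending step follows; the tau0/kappa values, the toy lambda vectors, and the class name are illustrative assumptions, not Hivemall defaults.

public class RhotSketch {
    public static void main(String[] args) {
        final double tau0 = 64.d, kappa = 0.7d;   // assumed values, for illustration only
        final float[] lambda = {1.0f, 1.0f};       // current topic-word weights for one word
        final float[] lambdaTilde = {2.0f, 0.5f};  // estimate from the current mini-batch
        for (int t = 0; t < 3; t++) {
            final double rhot = Math.pow(tau0 + t, -kappa); // decays toward 0 as t grows
            for (int k = 0; k < lambda.length; k++) {
                // same convex blend as mStep(): (1 - rhot) * old + rhot * new
                lambda[k] = (float) ((1.d - rhot) * lambda[k] + rhot * lambdaTilde[k]);
            }
            System.out.println("t=" + t + ", rhot=" + rhot + ", lambda=" + java.util.Arrays.toString(lambda));
        }
    }
}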

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java 
b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index e8e337d..711aac7 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -23,9 +23,12 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Random;
 
+import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
+import org.apache.commons.math3.distribution.GammaDistribution;
+
 public final class ArrayUtils {
 
     /**
@@ -715,4 +718,21 @@ public final class ArrayUtils {
         return cnt;
     }
 
+    @Nonnull
+    public static float[] newInstance(@Nonnegative int size, float filledValue) {
+        final float[] a = new float[size];
+        Arrays.fill(a, filledValue);
+        return a;
+    }
+    
+    @Nonnull
+    public static float[] newRandomFloatArray(@Nonnegative final int size,
+            @Nonnull final GammaDistribution gd) {
+        final float[] ret = new float[size];
+        for (int i = 0; i < size; i++) {
+            ret[i] = (float) gd.sample();
+        }
+        return ret;
+    }
+
 }
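
The two helpers added here cover the model's two initialization modes: a constant fill (e.g., gamma set to 1.0 at prediction time) and random draws from a Gamma distribution (training-time initialization of gamma and lambda). A small usage sketch is below; the Gamma(100, 0.01) shape/scale is a common choice for LDA initialization but is an assumption here, not a value taken from this commit.

import org.apache.commons.math3.distribution.GammaDistribution;

import hivemall.utils.lang.ArrayUtils;

public class ArrayInitSketch {
    public static void main(String[] args) {
        // Draws concentrate around shape * scale = 1.0, giving a mildly noisy start.
        final GammaDistribution gd = new GammaDistribution(100.d, 1.d / 100.d);
        final float[] randomInit = ArrayUtils.newRandomFloatArray(10, gd); // e.g., lambda for a new word
        final float[] constantInit = ArrayUtils.newInstance(10, 1.f);      // e.g., gamma at prediction time
        System.out.println(randomInit[0] + " vs. " + constantInit[0]);
    }
}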

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java 
b/core/src/main/java/hivemall/utils/math/MathUtils.java
index b71d165..7fdea55 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -38,6 +38,9 @@ import java.util.Random;
 
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.math3.special.Gamma;
 
 public final class MathUtils {
 
@@ -311,4 +314,44 @@ public final class MathUtils {
         return perm;
     }
 
+    public static float sum(@Nullable final float[] a) {
+        if (a == null) {
+            return 0.f;
+        }
+
+        float sum = 0.f;
+        for (float v : a) {
+            sum += v;
+        }
+        return sum;
+    }
+
+    public static float sum(@Nullable final float[] a, @Nonnegative final int size) {
+        if (a == null) {
+            return 0.f;
+        }
+
+        float sum = 0.f;
+        for (int i = 0; i < size; i++) {
+            sum += a[i];
+        }
+        return sum;
+    }
+
+    public static void add(@Nonnull final float[] dst, @Nonnull final float[] toAdd, final int size) {
+        for (int i = 0; i < size; i++) {
+            dst[i] += toAdd[i];
+        }
+    }
+
+    @Nonnull
+    public static float[] digamma(@Nonnull final float[] a) {
+        final int k = a.length;
+        final float[] ret = new float[k];
+        for (int i = 0; i < k; i++) {
+            ret[i] = (float) Gamma.digamma(a[i]);
+        }
+        return ret;
+    }
+
 }
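
The digamma helper feeds the Dirichlet expectations used throughout the E-step: for theta ~ Dirichlet(gamma), E[log theta_k] = digamma(gamma_k) - digamma(sum_j gamma_j), and likewise for beta with lambda. A short sketch of that identity using the new utilities; the gamma values and class name are chosen arbitrarily for illustration.

import org.apache.commons.math3.special.Gamma;

import hivemall.utils.math.MathUtils;

public class DirichletExpectationSketch {
    public static void main(String[] args) {
        final float[] gamma = {0.5f, 1.5f, 3.0f}; // illustrative variational parameters
        final float[] digamma = MathUtils.digamma(gamma);
        final float digammaSum = (float) Gamma.digamma(MathUtils.sum(gamma));
        for (int k = 0; k < gamma.length; k++) {
            // E[log theta_k] under Dirichlet(gamma); always negative since gamma_k < sum(gamma)
            System.out.println("Elogtheta[" + k + "] = " + (digamma[k] - digammaSum));
        }
    }
}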

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java 
b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
new file mode 100644
index 0000000..a23d917
--- /dev/null
+++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class LDAPredictUDAFTest {
+    LDAPredictUDAF udaf;
+    GenericUDAFEvaluator evaluator;
+    ObjectInspector[] inputOIs;
+    ObjectInspector[] partialOI;
+    LDAPredictUDAF.OnlineLDAPredictAggregationBuffer agg;
+
+    String[] words;
+    int[] labels;
+    float[] lambdas;
+
+    @Test(expected=UDFArgumentException.class)
+    public void testWithoutOption() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+    }
+
+    @Test(expected=UDFArgumentException.class)
+    public void testWithoutTopicOption() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-alpha 0.1")};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+    }
+
+    @Before
+    public void setUp() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2")};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        ArrayList<String> fieldNames = new ArrayList<String>();
+        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+        fieldNames.add("wcList");
+        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector));
+
+        fieldNames.add("lambdaMap");
+        fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
+
+        fieldNames.add("topic");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+        fieldNames.add("alpha");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+        fieldNames.add("delta");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+        partialOI = new ObjectInspector[4];
+        partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+
+        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
+        words = new String[] {"fruits", "vegetables", "healthy", "flu", "apples", "oranges", "like", "avocados", "colds",
+            "colds", "avocados", "oranges", "like", "apples", "flu", "healthy", "vegetables", "fruits"};
+        labels = new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+        lambdas = new float[] {0.3339331f, 0.3324783f, 0.33209667f, 3.2804057E-4f, 3.0303953E-4f, 2.4860457E-4f, 2.41481E-4f, 2.3554532E-4f, 1.352576E-4f,
+            0.1660153f, 0.16596903f, 0.1659654f, 0.1659627f, 0.16593699f, 0.1659259f, 0.0017611005f, 0.0015791848f, 8.84464E-4f};
+    }
+
+    @Test
+    public void test() throws Exception {
+        final Map<String, Float> doc1 = new HashMap<String, Float>();
+        doc1.put("fruits", 1.f);
+        doc1.put("healthy", 1.f);
+        doc1.put("vegetables", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc1.get(word), labels[i], lambdas[i]});
+        }
+        float[] doc1Distr = agg.get();
+
+        final Map<String, Float> doc2 = new HashMap<String, Float>();
+        doc2.put("apples", 1.f);
+        doc2.put("avocados", 1.f);
+        doc2.put("colds", 1.f);
+        doc2.put("flu", 1.f);
+        doc2.put("like", 2.f);
+        doc2.put("oranges", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc2.get(word), labels[i], lambdas[i]});
+        }
+        float[] doc2Distr = agg.get();
+
+        Assert.assertTrue(doc1Distr[0] > doc2Distr[0]);
+        Assert.assertTrue(doc1Distr[1] < doc2Distr[1]);
+    }
+
+
+    @Test
+    public void testMerge() throws Exception {
+        final Map<String, Float> doc = new HashMap<String, Float>();
+        doc.put("apples", 1.f);
+        doc.put("avocados", 1.f);
+        doc.put("colds", 1.f);
+        doc.put("flu", 1.f);
+        doc.put("like", 2.f);
+        doc.put("oranges", 1.f);
+
+        Object[] partials = new Object[3];
+
+        // bin #1
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < 6; i++) {
+            evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), labels[i], lambdas[i]});
+        }
+        partials[0] = evaluator.terminatePartial(agg);
+
+        // bin #2
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 6; i < 12; i++) {
+            evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), labels[i], lambdas[i]});
+        }
+        partials[1] = evaluator.terminatePartial(agg);
+
+        // bin #3
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 12; i < 18; i++) {
+            evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), labels[i], lambdas[i]});
+        }
+
+        partials[2] = evaluator.terminatePartial(agg);
+
+        // merge in a different order
+        final int[][] orders = new int[][] {{0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}};
+        for (int i = 0; i < orders.length; i++) {
+            evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
+            evaluator.reset(agg);
+
+            evaluator.merge(agg, partials[orders[i][0]]);
+            evaluator.merge(agg, partials[orders[i][1]]);
+            evaluator.merge(agg, partials[orders[i][2]]);
+
+            float[] distr = agg.get();
+            Assert.assertTrue(distr[0] < distr[1]);
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java 
b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
new file mode 100644
index 0000000..d1e3f81
--- /dev/null
+++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.Arrays;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class LDAUDTFTest {
+    private static final boolean DEBUG = false;
+
+    @Test
+    public void test() throws HiveException {
+        LDAUDTF udtf = new LDAUDTF();
+
+        ObjectInspector[] argOIs = new ObjectInspector[] {
+            ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+            ObjectInspectorUtils.getConstantObjectInspector(
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2 -num_docs 2 -s 1")};
+
+        udtf.initialize(argOIs);
+
+        String[] doc1 = new String[]{"fruits:1", "healthy:1", "vegetables:1"};
+        String[] doc2 = new String[]{"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"};
+        for (int it = 0; it < 5; it++) {
+            udtf.process(new Object[]{ Arrays.asList(doc1) });
+            udtf.process(new Object[]{ Arrays.asList(doc2) });
+        }
+
+        SortedMap<Float, List<String>> topicWords;
+
+        println("Topic 0:");
+        println("========");
+        topicWords = udtf.getTopicWords(0);
+        for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+            List<String> words = e.getValue();
+            for (int i = 0; i < words.size(); i++) {
+                println(e.getKey() + " " + words.get(i));
+            }
+        }
+        println("========");
+
+        println("Topic 1:");
+        println("========");
+        topicWords = udtf.getTopicWords(1);
+        for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+            List<String> words = e.getValue();
+            for (int i = 0; i < words.size(); i++) {
+                println(e.getKey() + " " + words.get(i));
+            }
+        }
+        println("========");
+
+        int k1, k2;
+        float[] topicDistr = udtf.getTopicDistribution(doc1);
+        if (topicDistr[0] > topicDistr[1]) {
+            // topic 0 MUST represent doc#1
+            k1 = 0;
+            k2 = 1;
+        } else {
+            k1 = 1;
+            k2 = 0;
+        }
+
+        Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+            + "and `vegetables` SHOULD be a more suitable topic word than `flu` in that topic",
+            udtf.getLambda("vegetables", k1) > udtf.getLambda("flu", k1));
+        Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+            + "and `avocados` SHOULD be a more suitable topic word than `healthy` in that topic",
+            udtf.getLambda("avocados", k2) > udtf.getLambda("healthy", k2));
+    }
+
+    private static void println(String msg) {
+        if (DEBUG) {
+            System.out.println(msg);
+        }
+    }
+}
